fixes
This commit is contained in:
parent
cf078c330c
commit
5bc0c3b2b4
@ -1,19 +1,17 @@
|
|||||||
package dnsMitm
|
package dnsMitmProxy
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"strconv"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
)
|
)
|
||||||
|
|
||||||
type DNSMITM struct {
|
type DNSMITMProxy struct {
|
||||||
ListenPort uint16
|
|
||||||
TargetDNSServerAddress string
|
TargetDNSServerAddress string
|
||||||
TargetDNSServerPort uint16
|
TargetDNSServerPort uint16
|
||||||
|
|
||||||
@ -21,7 +19,7 @@ type DNSMITM struct {
|
|||||||
ResponseHook func(net.Addr, dns.Msg, dns.Msg, string) (*dns.Msg, error)
|
ResponseHook func(net.Addr, dns.Msg, dns.Msg, string) (*dns.Msg, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p DNSMITM) requestDNS(req []byte, network string) ([]byte, error) {
|
func (p DNSMITMProxy) requestDNS(req []byte, network string) ([]byte, error) {
|
||||||
serverConn, err := net.Dial(network, fmt.Sprintf("%s:%d", p.TargetDNSServerAddress, p.TargetDNSServerPort))
|
serverConn, err := net.Dial(network, fmt.Sprintf("%s:%d", p.TargetDNSServerAddress, p.TargetDNSServerPort))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to dial DNS server: %w", err)
|
return nil, fmt.Errorf("failed to dial DNS server: %w", err)
|
||||||
@ -65,7 +63,7 @@ func (p DNSMITM) requestDNS(req []byte, network string) ([]byte, error) {
|
|||||||
return resp[:n], nil
|
return resp[:n], nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p DNSMITM) processReq(clientAddr net.Addr, req []byte, network string) ([]byte, error) {
|
func (p DNSMITMProxy) processReq(clientAddr net.Addr, req []byte, network string) ([]byte, error) {
|
||||||
var reqMsg dns.Msg
|
var reqMsg dns.Msg
|
||||||
if p.RequestHook != nil || p.ResponseHook != nil {
|
if p.RequestHook != nil || p.ResponseHook != nil {
|
||||||
err := reqMsg.Unpack(req)
|
err := reqMsg.Unpack(req)
|
||||||
@ -97,14 +95,14 @@ func (p DNSMITM) processReq(clientAddr net.Addr, req []byte, network string) ([]
|
|||||||
|
|
||||||
resp, err := p.requestDNS(req, network)
|
resp, err := p.requestDNS(req, network)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to send request")
|
return nil, fmt.Errorf("failed to send request: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if p.ResponseHook != nil {
|
if p.ResponseHook != nil {
|
||||||
var respMsg dns.Msg
|
var respMsg dns.Msg
|
||||||
err = respMsg.Unpack(resp)
|
err = respMsg.Unpack(resp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to parse response")
|
return nil, fmt.Errorf("failed to parse response: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
modifiedResp, err := p.ResponseHook(clientAddr, reqMsg, respMsg, network)
|
modifiedResp, err := p.ResponseHook(clientAddr, reqMsg, respMsg, network)
|
||||||
@ -123,12 +121,7 @@ func (p DNSMITM) processReq(clientAddr net.Addr, req []byte, network string) ([]
|
|||||||
return resp, nil
|
return resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p DNSMITM) ListenTCP(ctx context.Context) error {
|
func (p DNSMITMProxy) ListenTCP(ctx context.Context, addr *net.TCPAddr) error {
|
||||||
addr, err := net.ResolveTCPAddr("tcp", "[::]:"+strconv.Itoa(int(p.ListenPort)))
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to resolve tcp address: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
listener, err := net.ListenTCP("tcp", addr)
|
listener, err := net.ListenTCP("tcp", addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to listen tcp port: %v", err)
|
return fmt.Errorf("failed to listen tcp port: %v", err)
|
||||||
@ -184,12 +177,7 @@ func (p DNSMITM) ListenTCP(ctx context.Context) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p DNSMITM) ListenUDP(ctx context.Context) error {
|
func (p DNSMITMProxy) ListenUDP(ctx context.Context, addr *net.UDPAddr) error {
|
||||||
addr, err := net.ResolveUDPAddr("udp", "[::]:"+strconv.Itoa(int(p.ListenPort)))
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to resolve udp address: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
conn, err := net.ListenUDP("udp", addr)
|
conn, err := net.ListenUDP("udp", addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to listen udp port: %v", err)
|
return fmt.Errorf("failed to listen udp port: %v", err)
|
||||||
@ -226,14 +214,8 @@ func (p DNSMITM) ListenUDP(ctx context.Context) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(listenPort uint16, targetDNSServerAddress string, targetDNSServerPort ...uint16) *DNSMITM {
|
func New() *DNSMITMProxy {
|
||||||
dnsMitm := &DNSMITM{
|
return &DNSMITMProxy{
|
||||||
ListenPort: listenPort,
|
TargetDNSServerPort: 53,
|
||||||
TargetDNSServerAddress: targetDNSServerAddress,
|
|
||||||
TargetDNSServerPort: 53,
|
|
||||||
}
|
}
|
||||||
if len(targetDNSServerPort) > 0 {
|
|
||||||
dnsMitm.TargetDNSServerPort = targetDNSServerPort[0]
|
|
||||||
}
|
|
||||||
return dnsMitm
|
|
||||||
}
|
}
|
145
kvas2.go
145
kvas2.go
@ -6,19 +6,20 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"os"
|
"os"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"kvas2-go/dns-mitm"
|
"kvas2-go/dns-mitm-proxy"
|
||||||
"kvas2-go/models"
|
"kvas2-go/models"
|
||||||
"kvas2-go/netfilter-helper"
|
"kvas2-go/netfilter-helper"
|
||||||
|
"kvas2-go/records"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
"github.com/vishvananda/netlink"
|
"github.com/vishvananda/netlink"
|
||||||
"github.com/vishvananda/netlink/nl"
|
"github.com/vishvananda/netlink/nl"
|
||||||
"golang.org/x/sys/unix"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@ -39,10 +40,10 @@ type Config struct {
|
|||||||
type App struct {
|
type App struct {
|
||||||
Config Config
|
Config Config
|
||||||
|
|
||||||
DNSMITM *dnsMitm.DNSMITM
|
DNSMITM *dnsMitmProxy.DNSMITMProxy
|
||||||
NetfilterHelper4 *netfilterHelper.NetfilterHelper
|
NetfilterHelper4 *netfilterHelper.NetfilterHelper
|
||||||
NetfilterHelper6 *netfilterHelper.NetfilterHelper
|
NetfilterHelper6 *netfilterHelper.NetfilterHelper
|
||||||
Records *Records
|
Records *records.Records
|
||||||
Groups map[uuid.UUID]*Group
|
Groups map[uuid.UUID]*Group
|
||||||
|
|
||||||
Link netlink.Link
|
Link netlink.Link
|
||||||
@ -54,7 +55,7 @@ type App struct {
|
|||||||
|
|
||||||
func (a *App) handleLink(event netlink.LinkUpdate) {
|
func (a *App) handleLink(event netlink.LinkUpdate) {
|
||||||
switch event.Change {
|
switch event.Change {
|
||||||
case unix.IFF_UP:
|
case 0x00000001:
|
||||||
log.Debug().
|
log.Debug().
|
||||||
Str("interface", event.Link.Attrs().Name).
|
Str("interface", event.Link.Attrs().Name).
|
||||||
Str("operstatestr", event.Attrs().OperState.String()).
|
Str("operstatestr", event.Attrs().OperState.String()).
|
||||||
@ -94,7 +95,6 @@ func (a *App) start(ctx context.Context) (err error) {
|
|||||||
newCtx, cancel := context.WithCancel(ctx)
|
newCtx, cancel := context.WithCancel(ctx)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
// TODO: Chan err
|
|
||||||
errChan := make(chan error)
|
errChan := make(chan error)
|
||||||
|
|
||||||
/*
|
/*
|
||||||
@ -102,16 +102,28 @@ func (a *App) start(ctx context.Context) (err error) {
|
|||||||
*/
|
*/
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
err := a.DNSMITM.ListenUDP(newCtx)
|
addr, err := net.ResolveUDPAddr("udp", "[::]:"+strconv.Itoa(int(a.Config.ListenDNSPort)))
|
||||||
|
if err != nil {
|
||||||
|
errChan <- fmt.Errorf("failed to resolve udp address: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err = a.DNSMITM.ListenUDP(newCtx, addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errChan <- fmt.Errorf("failed to serve DNS UDP proxy: %v", err)
|
errChan <- fmt.Errorf("failed to serve DNS UDP proxy: %v", err)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
err := a.DNSMITM.ListenTCP(newCtx)
|
addr, err := net.ResolveTCPAddr("tcp", "[::]:"+strconv.Itoa(int(a.Config.ListenDNSPort)))
|
||||||
|
if err != nil {
|
||||||
|
errChan <- fmt.Errorf("failed to resolve tcp address: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err = a.DNSMITM.ListenTCP(newCtx, addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errChan <- fmt.Errorf("failed to serve DNS TCP proxy: %v", err)
|
errChan <- fmt.Errorf("failed to serve DNS TCP proxy: %v", err)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
@ -125,20 +137,14 @@ func (a *App) start(ctx context.Context) (err error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to override DNS (IPv4): %v", err)
|
return fmt.Errorf("failed to override DNS (IPv4): %v", err)
|
||||||
}
|
}
|
||||||
defer func() {
|
defer func() { _ = a.dnsOverrider4.Disable() }()
|
||||||
// TODO: Handle error
|
|
||||||
_ = a.dnsOverrider4.Disable()
|
|
||||||
}()
|
|
||||||
|
|
||||||
a.dnsOverrider6 = a.NetfilterHelper6.PortRemap(fmt.Sprintf("%sDNSOR", a.Config.ChainPrefix), 53, a.Config.ListenDNSPort, addrList)
|
a.dnsOverrider6 = a.NetfilterHelper6.PortRemap(fmt.Sprintf("%sDNSOR", a.Config.ChainPrefix), 53, a.Config.ListenDNSPort, addrList)
|
||||||
err = a.dnsOverrider6.Enable()
|
err = a.dnsOverrider6.Enable()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to override DNS (IPv6): %v", err)
|
return fmt.Errorf("failed to override DNS (IPv6): %v", err)
|
||||||
}
|
}
|
||||||
defer func() {
|
defer func() { _ = a.dnsOverrider6.Disable() }()
|
||||||
// TODO: Handle error
|
|
||||||
_ = a.dnsOverrider6.Disable()
|
|
||||||
}()
|
|
||||||
|
|
||||||
/*
|
/*
|
||||||
Groups
|
Groups
|
||||||
@ -152,7 +158,6 @@ func (a *App) start(ctx context.Context) (err error) {
|
|||||||
}
|
}
|
||||||
defer func() {
|
defer func() {
|
||||||
for _, group := range a.Groups {
|
for _, group := range a.Groups {
|
||||||
// TODO: Handle error
|
|
||||||
_ = group.Disable()
|
_ = group.Disable()
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
@ -170,13 +175,16 @@ func (a *App) start(ctx context.Context) (err error) {
|
|||||||
return fmt.Errorf("error while serve UNIX socket: %v", err)
|
return fmt.Errorf("error while serve UNIX socket: %v", err)
|
||||||
}
|
}
|
||||||
defer func() {
|
defer func() {
|
||||||
// TODO: Handle error
|
|
||||||
_ = socket.Close()
|
_ = socket.Close()
|
||||||
_ = os.Remove(socketPath)
|
_ = os.Remove(socketPath)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
for {
|
for {
|
||||||
|
if newCtx.Err() != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
conn, err := socket.Accept()
|
conn, err := socket.Accept()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if !strings.Contains(err.Error(), "use of closed network connection") {
|
if !strings.Contains(err.Error(), "use of closed network connection") {
|
||||||
@ -186,10 +194,7 @@ func (a *App) start(ctx context.Context) (err error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
go func(conn net.Conn) {
|
go func(conn net.Conn) {
|
||||||
defer func() {
|
defer func() { _ = conn.Close() }()
|
||||||
// TODO: Handle error
|
|
||||||
_ = conn.Close()
|
|
||||||
}()
|
|
||||||
|
|
||||||
buf := make([]byte, 1024)
|
buf := make([]byte, 1024)
|
||||||
n, err := conn.Read(buf)
|
n, err := conn.Read(buf)
|
||||||
@ -206,6 +211,12 @@ func (a *App) start(ctx context.Context) (err error) {
|
|||||||
log.Error().Err(err).Msg("error while fixing iptables after netfilter.d")
|
log.Error().Err(err).Msg("error while fixing iptables after netfilter.d")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if a.dnsOverrider6.Enabled {
|
||||||
|
err = a.dnsOverrider6.PutIPTable(args[2])
|
||||||
|
if err != nil {
|
||||||
|
log.Error().Err(err).Msg("error while fixing iptables after netfilter.d")
|
||||||
|
}
|
||||||
|
}
|
||||||
for _, group := range a.Groups {
|
for _, group := range a.Groups {
|
||||||
if group.ifaceToIPSetNAT.Enabled {
|
if group.ifaceToIPSetNAT.Enabled {
|
||||||
err := group.ifaceToIPSetNAT.PutIPTable(args[2])
|
err := group.ifaceToIPSetNAT.PutIPTable(args[2])
|
||||||
@ -240,7 +251,6 @@ func (a *App) start(ctx context.Context) (err error) {
|
|||||||
case event := <-linkUpdateChannel:
|
case event := <-linkUpdateChannel:
|
||||||
a.handleLink(event)
|
a.handleLink(event)
|
||||||
case err := <-errChan:
|
case err := <-errChan:
|
||||||
close(errChan)
|
|
||||||
return err
|
return err
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return nil
|
return nil
|
||||||
@ -288,22 +298,17 @@ func (a *App) AddGroup(group *models.Group) error {
|
|||||||
Group: group,
|
Group: group,
|
||||||
iptables: a.NetfilterHelper4.IPTables,
|
iptables: a.NetfilterHelper4.IPTables,
|
||||||
ipset: ipset,
|
ipset: ipset,
|
||||||
ifaceToIPSetNAT: a.NetfilterHelper4.IfaceToIPSet(fmt.Sprintf("%sR_%d", a.Config.ChainPrefix, group.ID), group.Interface, ipsetName, false),
|
ifaceToIPSetNAT: a.NetfilterHelper4.IfaceToIPSet(fmt.Sprintf("%sR_%8x", a.Config.ChainPrefix, group.ID.ID()), group.Interface, ipsetName, false),
|
||||||
}
|
}
|
||||||
|
grp.ifaceToIPSetNAT.SoftwareMode = a.Config.UseSoftwareRouting
|
||||||
a.Groups[grp.ID] = grp
|
a.Groups[grp.ID] = grp
|
||||||
return a.SyncGroup(grp)
|
return a.SyncGroup(grp)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *App) SyncGroup(group *Group) error {
|
func (a *App) SyncGroup(group *Group) error {
|
||||||
processedDomains := make(map[string]struct{})
|
|
||||||
newIpsetAddressesMap := make(map[string]time.Duration)
|
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
|
|
||||||
oldIpsetAddresses, err := group.ListIPv4()
|
addresses := make(map[string]time.Duration)
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to get old ipset list: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
knownDomains := a.Records.ListKnownDomains()
|
knownDomains := a.Records.ListKnownDomains()
|
||||||
for _, domain := range group.Rules {
|
for _, domain := range group.Rules {
|
||||||
if !domain.IsEnabled() {
|
if !domain.IsEnabled() {
|
||||||
@ -315,26 +320,24 @@ func (a *App) SyncGroup(group *Group) error {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
cnames := a.Records.GetCNameRecords(domainName, true)
|
domainAddresses := a.Records.GetARecords(domainName)
|
||||||
if len(cnames) == 0 {
|
for _, address := range domainAddresses {
|
||||||
continue
|
|
||||||
}
|
|
||||||
for _, cname := range cnames {
|
|
||||||
processedDomains[cname] = struct{}{}
|
|
||||||
}
|
|
||||||
|
|
||||||
addresses := a.Records.GetARecords(domainName)
|
|
||||||
for _, address := range addresses {
|
|
||||||
ttl := now.Sub(address.Deadline)
|
ttl := now.Sub(address.Deadline)
|
||||||
if oldTTL, ok := newIpsetAddressesMap[string(address.Address)]; !ok || ttl > oldTTL {
|
if oldTTL, ok := addresses[string(address.Address)]; !ok || ttl > oldTTL {
|
||||||
newIpsetAddressesMap[string(address.Address)] = ttl
|
addresses[string(address.Address)] = ttl
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for addr, ttl := range newIpsetAddressesMap {
|
currentAddresses, err := group.ListIPv4()
|
||||||
if _, exists := oldIpsetAddresses[addr]; exists {
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get old ipset list: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for addr, ttl := range addresses {
|
||||||
|
// TODO: Check TTL
|
||||||
|
if _, exists := currentAddresses[addr]; exists {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
ip := net.IP(addr)
|
ip := net.IP(addr)
|
||||||
@ -344,11 +347,16 @@ func (a *App) SyncGroup(group *Group) error {
|
|||||||
Str("address", ip.String()).
|
Str("address", ip.String()).
|
||||||
Err(err).
|
Err(err).
|
||||||
Msg("failed to add address")
|
Msg("failed to add address")
|
||||||
|
} else {
|
||||||
|
log.Trace().
|
||||||
|
Str("address", ip.String()).
|
||||||
|
Err(err).
|
||||||
|
Msg("add address")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for addr := range oldIpsetAddresses {
|
for addr := range currentAddresses {
|
||||||
if _, exists := newIpsetAddressesMap[addr]; exists {
|
if _, ok := addresses[addr]; ok {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
ip := net.IP(addr)
|
ip := net.IP(addr)
|
||||||
@ -362,7 +370,7 @@ func (a *App) SyncGroup(group *Group) error {
|
|||||||
log.Trace().
|
log.Trace().
|
||||||
Str("address", ip.String()).
|
Str("address", ip.String()).
|
||||||
Err(err).
|
Err(err).
|
||||||
Msg("add address")
|
Msg("del address")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -400,9 +408,9 @@ func (a *App) processARecord(aRecord dns.A) {
|
|||||||
ttlDuration = a.Config.MinimalTTL
|
ttlDuration = a.Config.MinimalTTL
|
||||||
}
|
}
|
||||||
|
|
||||||
a.Records.AddARecord(aRecord.Hdr.Name, aRecord.A, ttlDuration)
|
a.Records.AddARecord(aRecord.Hdr.Name[:len(aRecord.Hdr.Name)-1], aRecord.A, ttlDuration)
|
||||||
|
|
||||||
names := a.Records.GetCNameRecords(aRecord.Hdr.Name, true)
|
names := a.Records.GetAliases(aRecord.Hdr.Name[:len(aRecord.Hdr.Name)-1])
|
||||||
for _, group := range a.Groups {
|
for _, group := range a.Groups {
|
||||||
Rule:
|
Rule:
|
||||||
for _, domain := range group.Rules {
|
for _, domain := range group.Rules {
|
||||||
@ -413,6 +421,7 @@ func (a *App) processARecord(aRecord dns.A) {
|
|||||||
if !domain.IsMatch(name) {
|
if !domain.IsMatch(name) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
// TODO: Check already existed
|
||||||
err := group.AddIPv4(aRecord.A, ttlDuration)
|
err := group.AddIPv4(aRecord.A, ttlDuration)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
@ -445,12 +454,12 @@ func (a *App) processCNameRecord(cNameRecord dns.CNAME) {
|
|||||||
ttlDuration = a.Config.MinimalTTL
|
ttlDuration = a.Config.MinimalTTL
|
||||||
}
|
}
|
||||||
|
|
||||||
a.Records.AddCNameRecord(cNameRecord.Hdr.Name, cNameRecord.Target, ttlDuration)
|
a.Records.AddCNameRecord(cNameRecord.Hdr.Name[:len(cNameRecord.Hdr.Name)-1], cNameRecord.Target, ttlDuration)
|
||||||
|
|
||||||
// TODO: Optimization
|
// TODO: Optimization
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
aRecords := a.Records.GetARecords(cNameRecord.Hdr.Name)
|
aRecords := a.Records.GetARecords(cNameRecord.Hdr.Name[:len(cNameRecord.Hdr.Name)-1])
|
||||||
names := a.Records.GetCNameRecords(cNameRecord.Hdr.Name, true)
|
names := a.Records.GetAliases(cNameRecord.Hdr.Name[:len(cNameRecord.Hdr.Name)-1])
|
||||||
for _, group := range a.Groups {
|
for _, group := range a.Groups {
|
||||||
Rule:
|
Rule:
|
||||||
for _, domain := range group.Rules {
|
for _, domain := range group.Rules {
|
||||||
@ -496,12 +505,6 @@ func (a *App) handleMessage(msg dns.Msg) {
|
|||||||
for _, rr := range msg.Answer {
|
for _, rr := range msg.Answer {
|
||||||
a.handleRecord(rr)
|
a.handleRecord(rr)
|
||||||
}
|
}
|
||||||
for _, rr := range msg.Ns {
|
|
||||||
a.handleRecord(rr)
|
|
||||||
}
|
|
||||||
for _, rr := range msg.Extra {
|
|
||||||
a.handleRecord(rr)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(config Config) (*App, error) {
|
func New(config Config) (*App, error) {
|
||||||
@ -511,14 +514,10 @@ func New(config Config) (*App, error) {
|
|||||||
|
|
||||||
app.Config = config
|
app.Config = config
|
||||||
|
|
||||||
app.DNSMITM = dnsMitm.New(app.Config.ListenDNSPort, app.Config.TargetDNSServerAddress)
|
app.DNSMITM = dnsMitmProxy.New()
|
||||||
|
app.DNSMITM.TargetDNSServerAddress = app.Config.TargetDNSServerAddress
|
||||||
|
app.DNSMITM.TargetDNSServerPort = 53
|
||||||
app.DNSMITM.RequestHook = func(clientAddr net.Addr, reqMsg dns.Msg, network string) (*dns.Msg, *dns.Msg, error) {
|
app.DNSMITM.RequestHook = func(clientAddr net.Addr, reqMsg dns.Msg, network string) (*dns.Msg, *dns.Msg, error) {
|
||||||
log.Debug().
|
|
||||||
Str("network", network).
|
|
||||||
Str("clientAddr", clientAddr.String()).
|
|
||||||
Str("name", reqMsg.Question[0].Name).
|
|
||||||
Msg("received DNS request")
|
|
||||||
|
|
||||||
// TODO: Need to understand why it not works in proxy mode
|
// TODO: Need to understand why it not works in proxy mode
|
||||||
if len(reqMsg.Question) == 1 && reqMsg.Question[0].Qtype == dns.TypePTR {
|
if len(reqMsg.Question) == 1 && reqMsg.Question[0].Qtype == dns.TypePTR {
|
||||||
respMsg := &dns.Msg{
|
respMsg := &dns.Msg{
|
||||||
@ -530,10 +529,6 @@ func New(config Config) (*App, error) {
|
|||||||
},
|
},
|
||||||
Question: reqMsg.Question,
|
Question: reqMsg.Question,
|
||||||
}
|
}
|
||||||
log.Debug().
|
|
||||||
Str("network", network).
|
|
||||||
Str("clientAddr", clientAddr.String()).
|
|
||||||
Msg("sending DNS response")
|
|
||||||
return nil, respMsg, nil
|
return nil, respMsg, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -551,20 +546,12 @@ func New(config Config) (*App, error) {
|
|||||||
}
|
}
|
||||||
respMsg.Answer = respMsg.Answer[:idx]
|
respMsg.Answer = respMsg.Answer[:idx]
|
||||||
|
|
||||||
if len(respMsg.Answer) != 0 {
|
|
||||||
log.Debug().
|
|
||||||
Str("network", network).
|
|
||||||
Str("clientAddr", clientAddr.String()).
|
|
||||||
Str("respMsg", respMsg.Answer[0].Header().Name).
|
|
||||||
Msg("sending DNS response")
|
|
||||||
}
|
|
||||||
|
|
||||||
app.handleMessage(respMsg)
|
app.handleMessage(respMsg)
|
||||||
|
|
||||||
return &respMsg, nil
|
return &respMsg, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
app.Records = NewRecords()
|
app.Records = records.New()
|
||||||
app.Groups = make(map[uuid.UUID]*Group, 0)
|
app.Groups = make(map[uuid.UUID]*Group, 0)
|
||||||
|
|
||||||
link, err := netlink.LinkByName(app.Config.LinkName)
|
link, err := netlink.LinkByName(app.Config.LinkName)
|
||||||
|
@ -179,7 +179,7 @@ func (r *IfaceToIPSet) ForceEnable() error {
|
|||||||
// IPTables rules
|
// IPTables rules
|
||||||
err = r.PutIPTable("all")
|
err = r.PutIPTable("all")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Mapping mark with table
|
// Mapping mark with table
|
||||||
@ -194,7 +194,7 @@ func (r *IfaceToIPSet) ForceEnable() error {
|
|||||||
|
|
||||||
err = r.IfaceHandle()
|
err = r.IfaceHandle()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
r.Enabled = true
|
r.Enabled = true
|
||||||
|
228
records.go
228
records.go
@ -1,228 +0,0 @@
|
|||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"net"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
type ARecord struct {
|
|
||||||
Address net.IP
|
|
||||||
Deadline time.Time
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewARecord(addr net.IP, deadline time.Time) *ARecord {
|
|
||||||
return &ARecord{
|
|
||||||
Address: addr,
|
|
||||||
Deadline: deadline,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type CNameRecord struct {
|
|
||||||
Alias string
|
|
||||||
Deadline time.Time
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewCNameRecord(domainName string, deadline time.Time) *CNameRecord {
|
|
||||||
return &CNameRecord{
|
|
||||||
Alias: domainName,
|
|
||||||
Deadline: deadline,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type Records struct {
|
|
||||||
mutex sync.RWMutex
|
|
||||||
ARecords map[string][]*ARecord
|
|
||||||
CNameRecords map[string]*CNameRecord
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *Records) cleanupARecords(now time.Time) {
|
|
||||||
for name, aRecords := range r.ARecords {
|
|
||||||
i := 0
|
|
||||||
for _, aRecord := range aRecords {
|
|
||||||
if now.After(aRecord.Deadline) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
aRecords[i] = aRecord
|
|
||||||
i++
|
|
||||||
}
|
|
||||||
aRecords = aRecords[:i]
|
|
||||||
if i == 0 {
|
|
||||||
delete(r.ARecords, name)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *Records) cleanupCNameRecords(now time.Time) {
|
|
||||||
for name, record := range r.CNameRecords {
|
|
||||||
if now.After(record.Deadline) {
|
|
||||||
delete(r.CNameRecords, name)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *Records) getAliasedDomain(now time.Time, domainName string) string {
|
|
||||||
processedDomains := make(map[string]struct{})
|
|
||||||
for {
|
|
||||||
if _, processed := processedDomains[domainName]; processed {
|
|
||||||
// Loop detected!
|
|
||||||
return ""
|
|
||||||
} else {
|
|
||||||
processedDomains[domainName] = struct{}{}
|
|
||||||
}
|
|
||||||
|
|
||||||
cname, ok := r.CNameRecords[domainName]
|
|
||||||
if !ok {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
if now.After(cname.Deadline) {
|
|
||||||
delete(r.CNameRecords, domainName)
|
|
||||||
break
|
|
||||||
}
|
|
||||||
domainName = cname.Alias
|
|
||||||
}
|
|
||||||
return domainName
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *Records) getActualARecords(now time.Time, domainName string) []*ARecord {
|
|
||||||
aRecords, ok := r.ARecords[domainName]
|
|
||||||
if !ok {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
i := 0
|
|
||||||
for _, aRecord := range aRecords {
|
|
||||||
if now.After(aRecord.Deadline) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
aRecords[i] = aRecord
|
|
||||||
i++
|
|
||||||
}
|
|
||||||
aRecords = aRecords[:i]
|
|
||||||
if i == 0 {
|
|
||||||
delete(r.ARecords, domainName)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return aRecords
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *Records) getActualCNames(now time.Time, domainName string, fromEnd bool) []string {
|
|
||||||
processedDomains := make(map[string]struct{})
|
|
||||||
cNameList := make([]string, 0)
|
|
||||||
if fromEnd {
|
|
||||||
domainName = r.getAliasedDomain(now, domainName)
|
|
||||||
cNameList = append(cNameList, domainName)
|
|
||||||
}
|
|
||||||
r.cleanupCNameRecords(now)
|
|
||||||
for {
|
|
||||||
if _, processed := processedDomains[domainName]; processed {
|
|
||||||
// Loop detected!
|
|
||||||
return nil
|
|
||||||
} else {
|
|
||||||
processedDomains[domainName] = struct{}{}
|
|
||||||
}
|
|
||||||
|
|
||||||
found := false
|
|
||||||
for aliasFrom, aliasTo := range r.CNameRecords {
|
|
||||||
if aliasTo.Alias == domainName {
|
|
||||||
cNameList = append(cNameList, aliasFrom)
|
|
||||||
domainName = aliasFrom
|
|
||||||
found = true
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if !found {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return cNameList
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *Records) Cleanup() {
|
|
||||||
r.mutex.Lock()
|
|
||||||
defer r.mutex.Unlock()
|
|
||||||
now := time.Now()
|
|
||||||
r.cleanupARecords(now)
|
|
||||||
r.cleanupCNameRecords(now)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *Records) GetCNameRecords(domainName string, fromEnd bool) []string {
|
|
||||||
r.mutex.RLock()
|
|
||||||
defer r.mutex.RUnlock()
|
|
||||||
now := time.Now()
|
|
||||||
|
|
||||||
return r.getActualCNames(now, domainName, fromEnd)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *Records) GetARecords(domainName string) []*ARecord {
|
|
||||||
r.mutex.Lock()
|
|
||||||
defer r.mutex.Unlock()
|
|
||||||
now := time.Now()
|
|
||||||
|
|
||||||
return r.getActualARecords(now, r.getAliasedDomain(now, domainName))
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *Records) AddCNameRecord(domainName string, cName string, ttl time.Duration) {
|
|
||||||
if domainName == cName {
|
|
||||||
// Can't assing to yourself
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
r.mutex.Lock()
|
|
||||||
defer r.mutex.Unlock()
|
|
||||||
now := time.Now()
|
|
||||||
|
|
||||||
delete(r.ARecords, domainName)
|
|
||||||
r.CNameRecords[domainName] = NewCNameRecord(cName, now.Add(ttl))
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *Records) AddARecord(domainName string, addr net.IP, ttl time.Duration) {
|
|
||||||
r.mutex.Lock()
|
|
||||||
defer r.mutex.Unlock()
|
|
||||||
now := time.Now()
|
|
||||||
|
|
||||||
delete(r.CNameRecords, domainName)
|
|
||||||
if _, ok := r.ARecords[domainName]; !ok {
|
|
||||||
r.ARecords[domainName] = make([]*ARecord, 0)
|
|
||||||
}
|
|
||||||
for _, aRecord := range r.ARecords[domainName] {
|
|
||||||
if bytes.Compare(aRecord.Address, addr) == 0 {
|
|
||||||
aRecord.Deadline = now.Add(ttl)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
r.ARecords[domainName] = append(r.ARecords[domainName], NewARecord(addr, now.Add(ttl)))
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *Records) ListKnownDomains() []string {
|
|
||||||
r.mutex.Lock()
|
|
||||||
defer r.mutex.Unlock()
|
|
||||||
now := time.Now()
|
|
||||||
r.cleanupARecords(now)
|
|
||||||
r.cleanupCNameRecords(now)
|
|
||||||
|
|
||||||
domains := map[string]struct{}{}
|
|
||||||
for name, _ := range r.ARecords {
|
|
||||||
domains[name] = struct{}{}
|
|
||||||
}
|
|
||||||
for name, _ := range r.CNameRecords {
|
|
||||||
domains[name] = struct{}{}
|
|
||||||
}
|
|
||||||
|
|
||||||
domainsList := make([]string, len(domains))
|
|
||||||
i := 0
|
|
||||||
for name, _ := range domains {
|
|
||||||
domainsList[i] = name
|
|
||||||
i++
|
|
||||||
}
|
|
||||||
return domainsList
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewRecords() *Records {
|
|
||||||
return &Records{
|
|
||||||
ARecords: make(map[string][]*ARecord),
|
|
||||||
CNameRecords: make(map[string]*CNameRecord),
|
|
||||||
}
|
|
||||||
}
|
|
167
records/records.go
Normal file
167
records/records.go
Normal file
@ -0,0 +1,167 @@
|
|||||||
|
package records
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ARecord struct {
|
||||||
|
Address net.IP
|
||||||
|
Deadline time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
type CNameRecord struct {
|
||||||
|
Alias string
|
||||||
|
Deadline time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
type Records struct {
|
||||||
|
mux sync.RWMutex
|
||||||
|
records map[string]interface{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Records) AddCNameRecord(domainName, alias string, ttl time.Duration) {
|
||||||
|
if domainName == alias {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
r.mux.Lock()
|
||||||
|
r.records[domainName] = &CNameRecord{
|
||||||
|
Alias: alias,
|
||||||
|
Deadline: time.Now().Add(ttl),
|
||||||
|
}
|
||||||
|
r.mux.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Records) AddARecord(domainName string, addr net.IP, ttl time.Duration) {
|
||||||
|
r.mux.Lock()
|
||||||
|
defer r.mux.Unlock()
|
||||||
|
|
||||||
|
deadline := time.Now().Add(ttl)
|
||||||
|
|
||||||
|
aRecords, _ := r.records[domainName].([]*ARecord)
|
||||||
|
for _, aRecord := range aRecords {
|
||||||
|
if bytes.Compare(aRecord.Address, addr) != 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
aRecord.Deadline = deadline
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
r.records[domainName] = append(aRecords, &ARecord{
|
||||||
|
Address: addr,
|
||||||
|
Deadline: deadline,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Records) GetAliases(domainName string) []string {
|
||||||
|
r.mux.Lock()
|
||||||
|
defer r.mux.Unlock()
|
||||||
|
r.cleanupRecords()
|
||||||
|
|
||||||
|
domains := make(map[string]struct{})
|
||||||
|
domains[domainName] = struct{}{}
|
||||||
|
|
||||||
|
for {
|
||||||
|
var addedNew bool
|
||||||
|
for name, aRecord := range r.records {
|
||||||
|
if _, ok := domains[name]; ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
cname, ok := aRecord.(*CNameRecord)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if _, ok = domains[cname.Alias]; !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
domains[name] = struct{}{}
|
||||||
|
addedNew = true
|
||||||
|
}
|
||||||
|
if !addedNew {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
domainList := make([]string, len(domains))
|
||||||
|
idx := 0
|
||||||
|
for name, _ := range domains {
|
||||||
|
domainList[idx] = name
|
||||||
|
idx++
|
||||||
|
}
|
||||||
|
|
||||||
|
return domainList
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Records) GetARecords(domainName string) []*ARecord {
|
||||||
|
r.mux.Lock()
|
||||||
|
defer r.mux.Unlock()
|
||||||
|
r.cleanupRecords()
|
||||||
|
|
||||||
|
loopDetect := make(map[string]struct{})
|
||||||
|
loopDetect[domainName] = struct{}{}
|
||||||
|
for {
|
||||||
|
switch v := r.records[domainName].(type) {
|
||||||
|
case *CNameRecord:
|
||||||
|
if _, ok := loopDetect[v.Alias]; ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
domainName = v.Alias
|
||||||
|
loopDetect[v.Alias] = struct{}{}
|
||||||
|
case []*ARecord:
|
||||||
|
return v
|
||||||
|
default:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Records) ListKnownDomains() []string {
|
||||||
|
r.mux.Lock()
|
||||||
|
defer r.mux.Unlock()
|
||||||
|
r.cleanupRecords()
|
||||||
|
|
||||||
|
domainsList := make([]string, len(r.records))
|
||||||
|
i := 0
|
||||||
|
for name, _ := range r.records {
|
||||||
|
domainsList[i] = name
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
return domainsList
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Records) cleanupRecords() {
|
||||||
|
now := time.Now()
|
||||||
|
for name, records := range r.records {
|
||||||
|
switch v := records.(type) {
|
||||||
|
case []*ARecord:
|
||||||
|
idx := 0
|
||||||
|
for _, aRecord := range v {
|
||||||
|
if now.After(aRecord.Deadline) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
v[idx] = aRecord
|
||||||
|
idx++
|
||||||
|
}
|
||||||
|
if idx == 0 {
|
||||||
|
delete(r.records, name)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
r.records[name] = v[:idx]
|
||||||
|
case *CNameRecord:
|
||||||
|
if !now.After(v.Deadline) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
delete(r.records, name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func New() *Records {
|
||||||
|
return &Records{
|
||||||
|
records: make(map[string]interface{}),
|
||||||
|
}
|
||||||
|
}
|
109
records/records_test.go
Normal file
109
records/records_test.go
Normal file
@ -0,0 +1,109 @@
|
|||||||
|
package records
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"slices"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestLoop(t *testing.T) {
|
||||||
|
r := New()
|
||||||
|
r.AddCNameRecord("1", "2", time.Minute)
|
||||||
|
r.AddCNameRecord("2", "1", time.Minute)
|
||||||
|
if r.GetARecords("1") != nil {
|
||||||
|
t.Fatal("loop detected")
|
||||||
|
}
|
||||||
|
if r.GetARecords("2") != nil {
|
||||||
|
t.Fatal("loop detected")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCName(t *testing.T) {
|
||||||
|
r := New()
|
||||||
|
r.AddARecord("example.com", []byte{1, 2, 3, 4}, time.Minute)
|
||||||
|
r.AddCNameRecord("gateway.example.com", "example.com", time.Minute)
|
||||||
|
records := r.GetARecords("gateway.example.com")
|
||||||
|
if records == nil {
|
||||||
|
t.Fatal("no records")
|
||||||
|
}
|
||||||
|
if bytes.Compare(records[0].Address, []byte{1, 2, 3, 4}) != 0 {
|
||||||
|
t.Fatal("cname mismatch")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestA(t *testing.T) {
|
||||||
|
r := New()
|
||||||
|
r.AddARecord("example.com", []byte{1, 2, 3, 4}, time.Minute)
|
||||||
|
records := r.GetARecords("example.com")
|
||||||
|
if records == nil {
|
||||||
|
t.Fatal("no records")
|
||||||
|
}
|
||||||
|
if bytes.Compare(records[0].Address, []byte{1, 2, 3, 4}) != 0 {
|
||||||
|
t.Fatal("cname mismatch")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeprecated(t *testing.T) {
|
||||||
|
r := New()
|
||||||
|
r.AddARecord("example.com", []byte{1, 2, 3, 4}, -time.Minute)
|
||||||
|
records := r.GetARecords("example.com")
|
||||||
|
if records != nil {
|
||||||
|
t.Fatal("deprecated records")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNotExistedA(t *testing.T) {
|
||||||
|
r := New()
|
||||||
|
records := r.GetARecords("example.com")
|
||||||
|
if records != nil {
|
||||||
|
t.Fatal("not existed records")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNotExistedCNameAlias(t *testing.T) {
|
||||||
|
r := New()
|
||||||
|
r.AddCNameRecord("gateway.example.com", "example.com", time.Minute)
|
||||||
|
records := r.GetARecords("gateway.example.com")
|
||||||
|
if records != nil {
|
||||||
|
t.Fatal("not existed records")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReplacing(t *testing.T) {
|
||||||
|
r := New()
|
||||||
|
r.AddCNameRecord("gateway.example.com", "example.com", time.Minute)
|
||||||
|
r.AddARecord("gateway.example.com", []byte{1, 2, 3, 4}, time.Minute)
|
||||||
|
records := r.GetARecords("gateway.example.com")
|
||||||
|
if bytes.Compare(records[0].Address, []byte{1, 2, 3, 4}) != 0 {
|
||||||
|
t.Fatal("mismatch")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAliases(t *testing.T) {
|
||||||
|
r := New()
|
||||||
|
r.AddARecord("1", []byte{1, 2, 3, 4}, time.Minute)
|
||||||
|
r.AddCNameRecord("2", "1", time.Minute)
|
||||||
|
r.AddCNameRecord("3", "2", time.Minute)
|
||||||
|
r.AddCNameRecord("4", "2", time.Minute)
|
||||||
|
r.AddCNameRecord("5", "1", time.Minute)
|
||||||
|
aliases := r.GetAliases("1")
|
||||||
|
if aliases == nil {
|
||||||
|
t.Fatal("no aliases")
|
||||||
|
}
|
||||||
|
if !slices.Contains(aliases, "1") {
|
||||||
|
t.Fatal("no 1")
|
||||||
|
}
|
||||||
|
if !slices.Contains(aliases, "2") {
|
||||||
|
t.Fatal("no 2")
|
||||||
|
}
|
||||||
|
if !slices.Contains(aliases, "3") {
|
||||||
|
t.Fatal("no 3")
|
||||||
|
}
|
||||||
|
if !slices.Contains(aliases, "4") {
|
||||||
|
t.Fatal("no 4")
|
||||||
|
}
|
||||||
|
if !slices.Contains(aliases, "5") {
|
||||||
|
t.Fatal("no 5")
|
||||||
|
}
|
||||||
|
}
|
Loading…
x
Reference in New Issue
Block a user