This commit is contained in:
Vladimir Avtsenov 2025-02-11 02:50:13 +03:00
parent cf078c330c
commit 5bc0c3b2b4
6 changed files with 355 additions and 338 deletions

View File

@ -1,19 +1,17 @@
package dnsMitm
package dnsMitmProxy
import (
"context"
"encoding/binary"
"fmt"
"net"
"strconv"
"time"
"github.com/miekg/dns"
"github.com/rs/zerolog/log"
)
type DNSMITM struct {
ListenPort uint16
type DNSMITMProxy struct {
TargetDNSServerAddress string
TargetDNSServerPort uint16
@ -21,7 +19,7 @@ type DNSMITM struct {
ResponseHook func(net.Addr, dns.Msg, dns.Msg, string) (*dns.Msg, error)
}
func (p DNSMITM) requestDNS(req []byte, network string) ([]byte, error) {
func (p DNSMITMProxy) requestDNS(req []byte, network string) ([]byte, error) {
serverConn, err := net.Dial(network, fmt.Sprintf("%s:%d", p.TargetDNSServerAddress, p.TargetDNSServerPort))
if err != nil {
return nil, fmt.Errorf("failed to dial DNS server: %w", err)
@ -65,7 +63,7 @@ func (p DNSMITM) requestDNS(req []byte, network string) ([]byte, error) {
return resp[:n], nil
}
func (p DNSMITM) processReq(clientAddr net.Addr, req []byte, network string) ([]byte, error) {
func (p DNSMITMProxy) processReq(clientAddr net.Addr, req []byte, network string) ([]byte, error) {
var reqMsg dns.Msg
if p.RequestHook != nil || p.ResponseHook != nil {
err := reqMsg.Unpack(req)
@ -97,14 +95,14 @@ func (p DNSMITM) processReq(clientAddr net.Addr, req []byte, network string) ([]
resp, err := p.requestDNS(req, network)
if err != nil {
return nil, fmt.Errorf("failed to send request")
return nil, fmt.Errorf("failed to send request: %w", err)
}
if p.ResponseHook != nil {
var respMsg dns.Msg
err = respMsg.Unpack(resp)
if err != nil {
return nil, fmt.Errorf("failed to parse response")
return nil, fmt.Errorf("failed to parse response: %w", err)
}
modifiedResp, err := p.ResponseHook(clientAddr, reqMsg, respMsg, network)
@ -123,12 +121,7 @@ func (p DNSMITM) processReq(clientAddr net.Addr, req []byte, network string) ([]
return resp, nil
}
func (p DNSMITM) ListenTCP(ctx context.Context) error {
addr, err := net.ResolveTCPAddr("tcp", "[::]:"+strconv.Itoa(int(p.ListenPort)))
if err != nil {
return fmt.Errorf("failed to resolve tcp address: %v", err)
}
func (p DNSMITMProxy) ListenTCP(ctx context.Context, addr *net.TCPAddr) error {
listener, err := net.ListenTCP("tcp", addr)
if err != nil {
return fmt.Errorf("failed to listen tcp port: %v", err)
@ -184,12 +177,7 @@ func (p DNSMITM) ListenTCP(ctx context.Context) error {
}
}
func (p DNSMITM) ListenUDP(ctx context.Context) error {
addr, err := net.ResolveUDPAddr("udp", "[::]:"+strconv.Itoa(int(p.ListenPort)))
if err != nil {
return fmt.Errorf("failed to resolve udp address: %v", err)
}
func (p DNSMITMProxy) ListenUDP(ctx context.Context, addr *net.UDPAddr) error {
conn, err := net.ListenUDP("udp", addr)
if err != nil {
return fmt.Errorf("failed to listen udp port: %v", err)
@ -226,14 +214,8 @@ func (p DNSMITM) ListenUDP(ctx context.Context) error {
}
}
func New(listenPort uint16, targetDNSServerAddress string, targetDNSServerPort ...uint16) *DNSMITM {
dnsMitm := &DNSMITM{
ListenPort: listenPort,
TargetDNSServerAddress: targetDNSServerAddress,
TargetDNSServerPort: 53,
func New() *DNSMITMProxy {
return &DNSMITMProxy{
TargetDNSServerPort: 53,
}
if len(targetDNSServerPort) > 0 {
dnsMitm.TargetDNSServerPort = targetDNSServerPort[0]
}
return dnsMitm
}

145
kvas2.go
View File

@ -6,19 +6,20 @@ import (
"fmt"
"net"
"os"
"strconv"
"strings"
"time"
"kvas2-go/dns-mitm"
"kvas2-go/dns-mitm-proxy"
"kvas2-go/models"
"kvas2-go/netfilter-helper"
"kvas2-go/records"
"github.com/google/uuid"
"github.com/miekg/dns"
"github.com/rs/zerolog/log"
"github.com/vishvananda/netlink"
"github.com/vishvananda/netlink/nl"
"golang.org/x/sys/unix"
)
var (
@ -39,10 +40,10 @@ type Config struct {
type App struct {
Config Config
DNSMITM *dnsMitm.DNSMITM
DNSMITM *dnsMitmProxy.DNSMITMProxy
NetfilterHelper4 *netfilterHelper.NetfilterHelper
NetfilterHelper6 *netfilterHelper.NetfilterHelper
Records *Records
Records *records.Records
Groups map[uuid.UUID]*Group
Link netlink.Link
@ -54,7 +55,7 @@ type App struct {
func (a *App) handleLink(event netlink.LinkUpdate) {
switch event.Change {
case unix.IFF_UP:
case 0x00000001:
log.Debug().
Str("interface", event.Link.Attrs().Name).
Str("operstatestr", event.Attrs().OperState.String()).
@ -94,7 +95,6 @@ func (a *App) start(ctx context.Context) (err error) {
newCtx, cancel := context.WithCancel(ctx)
defer cancel()
// TODO: Chan err
errChan := make(chan error)
/*
@ -102,16 +102,28 @@ func (a *App) start(ctx context.Context) (err error) {
*/
go func() {
err := a.DNSMITM.ListenUDP(newCtx)
addr, err := net.ResolveUDPAddr("udp", "[::]:"+strconv.Itoa(int(a.Config.ListenDNSPort)))
if err != nil {
errChan <- fmt.Errorf("failed to resolve udp address: %v", err)
return
}
err = a.DNSMITM.ListenUDP(newCtx, addr)
if err != nil {
errChan <- fmt.Errorf("failed to serve DNS UDP proxy: %v", err)
return
}
}()
go func() {
err := a.DNSMITM.ListenTCP(newCtx)
addr, err := net.ResolveTCPAddr("tcp", "[::]:"+strconv.Itoa(int(a.Config.ListenDNSPort)))
if err != nil {
errChan <- fmt.Errorf("failed to resolve tcp address: %v", err)
return
}
err = a.DNSMITM.ListenTCP(newCtx, addr)
if err != nil {
errChan <- fmt.Errorf("failed to serve DNS TCP proxy: %v", err)
return
}
}()
@ -125,20 +137,14 @@ func (a *App) start(ctx context.Context) (err error) {
if err != nil {
return fmt.Errorf("failed to override DNS (IPv4): %v", err)
}
defer func() {
// TODO: Handle error
_ = a.dnsOverrider4.Disable()
}()
defer func() { _ = a.dnsOverrider4.Disable() }()
a.dnsOverrider6 = a.NetfilterHelper6.PortRemap(fmt.Sprintf("%sDNSOR", a.Config.ChainPrefix), 53, a.Config.ListenDNSPort, addrList)
err = a.dnsOverrider6.Enable()
if err != nil {
return fmt.Errorf("failed to override DNS (IPv6): %v", err)
}
defer func() {
// TODO: Handle error
_ = a.dnsOverrider6.Disable()
}()
defer func() { _ = a.dnsOverrider6.Disable() }()
/*
Groups
@ -152,7 +158,6 @@ func (a *App) start(ctx context.Context) (err error) {
}
defer func() {
for _, group := range a.Groups {
// TODO: Handle error
_ = group.Disable()
}
}()
@ -170,13 +175,16 @@ func (a *App) start(ctx context.Context) (err error) {
return fmt.Errorf("error while serve UNIX socket: %v", err)
}
defer func() {
// TODO: Handle error
_ = socket.Close()
_ = os.Remove(socketPath)
}()
go func() {
for {
if newCtx.Err() != nil {
return
}
conn, err := socket.Accept()
if err != nil {
if !strings.Contains(err.Error(), "use of closed network connection") {
@ -186,10 +194,7 @@ func (a *App) start(ctx context.Context) (err error) {
}
go func(conn net.Conn) {
defer func() {
// TODO: Handle error
_ = conn.Close()
}()
defer func() { _ = conn.Close() }()
buf := make([]byte, 1024)
n, err := conn.Read(buf)
@ -206,6 +211,12 @@ func (a *App) start(ctx context.Context) (err error) {
log.Error().Err(err).Msg("error while fixing iptables after netfilter.d")
}
}
if a.dnsOverrider6.Enabled {
err = a.dnsOverrider6.PutIPTable(args[2])
if err != nil {
log.Error().Err(err).Msg("error while fixing iptables after netfilter.d")
}
}
for _, group := range a.Groups {
if group.ifaceToIPSetNAT.Enabled {
err := group.ifaceToIPSetNAT.PutIPTable(args[2])
@ -240,7 +251,6 @@ func (a *App) start(ctx context.Context) (err error) {
case event := <-linkUpdateChannel:
a.handleLink(event)
case err := <-errChan:
close(errChan)
return err
case <-ctx.Done():
return nil
@ -288,22 +298,17 @@ func (a *App) AddGroup(group *models.Group) error {
Group: group,
iptables: a.NetfilterHelper4.IPTables,
ipset: ipset,
ifaceToIPSetNAT: a.NetfilterHelper4.IfaceToIPSet(fmt.Sprintf("%sR_%d", a.Config.ChainPrefix, group.ID), group.Interface, ipsetName, false),
ifaceToIPSetNAT: a.NetfilterHelper4.IfaceToIPSet(fmt.Sprintf("%sR_%8x", a.Config.ChainPrefix, group.ID.ID()), group.Interface, ipsetName, false),
}
grp.ifaceToIPSetNAT.SoftwareMode = a.Config.UseSoftwareRouting
a.Groups[grp.ID] = grp
return a.SyncGroup(grp)
}
func (a *App) SyncGroup(group *Group) error {
processedDomains := make(map[string]struct{})
newIpsetAddressesMap := make(map[string]time.Duration)
now := time.Now()
oldIpsetAddresses, err := group.ListIPv4()
if err != nil {
return fmt.Errorf("failed to get old ipset list: %w", err)
}
addresses := make(map[string]time.Duration)
knownDomains := a.Records.ListKnownDomains()
for _, domain := range group.Rules {
if !domain.IsEnabled() {
@ -315,26 +320,24 @@ func (a *App) SyncGroup(group *Group) error {
continue
}
cnames := a.Records.GetCNameRecords(domainName, true)
if len(cnames) == 0 {
continue
}
for _, cname := range cnames {
processedDomains[cname] = struct{}{}
}
addresses := a.Records.GetARecords(domainName)
for _, address := range addresses {
domainAddresses := a.Records.GetARecords(domainName)
for _, address := range domainAddresses {
ttl := now.Sub(address.Deadline)
if oldTTL, ok := newIpsetAddressesMap[string(address.Address)]; !ok || ttl > oldTTL {
newIpsetAddressesMap[string(address.Address)] = ttl
if oldTTL, ok := addresses[string(address.Address)]; !ok || ttl > oldTTL {
addresses[string(address.Address)] = ttl
}
}
}
}
for addr, ttl := range newIpsetAddressesMap {
if _, exists := oldIpsetAddresses[addr]; exists {
currentAddresses, err := group.ListIPv4()
if err != nil {
return fmt.Errorf("failed to get old ipset list: %w", err)
}
for addr, ttl := range addresses {
// TODO: Check TTL
if _, exists := currentAddresses[addr]; exists {
continue
}
ip := net.IP(addr)
@ -344,11 +347,16 @@ func (a *App) SyncGroup(group *Group) error {
Str("address", ip.String()).
Err(err).
Msg("failed to add address")
} else {
log.Trace().
Str("address", ip.String()).
Err(err).
Msg("add address")
}
}
for addr := range oldIpsetAddresses {
if _, exists := newIpsetAddressesMap[addr]; exists {
for addr := range currentAddresses {
if _, ok := addresses[addr]; ok {
continue
}
ip := net.IP(addr)
@ -362,7 +370,7 @@ func (a *App) SyncGroup(group *Group) error {
log.Trace().
Str("address", ip.String()).
Err(err).
Msg("add address")
Msg("del address")
}
}
@ -400,9 +408,9 @@ func (a *App) processARecord(aRecord dns.A) {
ttlDuration = a.Config.MinimalTTL
}
a.Records.AddARecord(aRecord.Hdr.Name, aRecord.A, ttlDuration)
a.Records.AddARecord(aRecord.Hdr.Name[:len(aRecord.Hdr.Name)-1], aRecord.A, ttlDuration)
names := a.Records.GetCNameRecords(aRecord.Hdr.Name, true)
names := a.Records.GetAliases(aRecord.Hdr.Name[:len(aRecord.Hdr.Name)-1])
for _, group := range a.Groups {
Rule:
for _, domain := range group.Rules {
@ -413,6 +421,7 @@ func (a *App) processARecord(aRecord dns.A) {
if !domain.IsMatch(name) {
continue
}
// TODO: Check already existed
err := group.AddIPv4(aRecord.A, ttlDuration)
if err != nil {
log.Error().
@ -445,12 +454,12 @@ func (a *App) processCNameRecord(cNameRecord dns.CNAME) {
ttlDuration = a.Config.MinimalTTL
}
a.Records.AddCNameRecord(cNameRecord.Hdr.Name, cNameRecord.Target, ttlDuration)
a.Records.AddCNameRecord(cNameRecord.Hdr.Name[:len(cNameRecord.Hdr.Name)-1], cNameRecord.Target, ttlDuration)
// TODO: Optimization
now := time.Now()
aRecords := a.Records.GetARecords(cNameRecord.Hdr.Name)
names := a.Records.GetCNameRecords(cNameRecord.Hdr.Name, true)
aRecords := a.Records.GetARecords(cNameRecord.Hdr.Name[:len(cNameRecord.Hdr.Name)-1])
names := a.Records.GetAliases(cNameRecord.Hdr.Name[:len(cNameRecord.Hdr.Name)-1])
for _, group := range a.Groups {
Rule:
for _, domain := range group.Rules {
@ -496,12 +505,6 @@ func (a *App) handleMessage(msg dns.Msg) {
for _, rr := range msg.Answer {
a.handleRecord(rr)
}
for _, rr := range msg.Ns {
a.handleRecord(rr)
}
for _, rr := range msg.Extra {
a.handleRecord(rr)
}
}
func New(config Config) (*App, error) {
@ -511,14 +514,10 @@ func New(config Config) (*App, error) {
app.Config = config
app.DNSMITM = dnsMitm.New(app.Config.ListenDNSPort, app.Config.TargetDNSServerAddress)
app.DNSMITM = dnsMitmProxy.New()
app.DNSMITM.TargetDNSServerAddress = app.Config.TargetDNSServerAddress
app.DNSMITM.TargetDNSServerPort = 53
app.DNSMITM.RequestHook = func(clientAddr net.Addr, reqMsg dns.Msg, network string) (*dns.Msg, *dns.Msg, error) {
log.Debug().
Str("network", network).
Str("clientAddr", clientAddr.String()).
Str("name", reqMsg.Question[0].Name).
Msg("received DNS request")
// TODO: Need to understand why it not works in proxy mode
if len(reqMsg.Question) == 1 && reqMsg.Question[0].Qtype == dns.TypePTR {
respMsg := &dns.Msg{
@ -530,10 +529,6 @@ func New(config Config) (*App, error) {
},
Question: reqMsg.Question,
}
log.Debug().
Str("network", network).
Str("clientAddr", clientAddr.String()).
Msg("sending DNS response")
return nil, respMsg, nil
}
@ -551,20 +546,12 @@ func New(config Config) (*App, error) {
}
respMsg.Answer = respMsg.Answer[:idx]
if len(respMsg.Answer) != 0 {
log.Debug().
Str("network", network).
Str("clientAddr", clientAddr.String()).
Str("respMsg", respMsg.Answer[0].Header().Name).
Msg("sending DNS response")
}
app.handleMessage(respMsg)
return &respMsg, nil
}
app.Records = NewRecords()
app.Records = records.New()
app.Groups = make(map[uuid.UUID]*Group, 0)
link, err := netlink.LinkByName(app.Config.LinkName)

View File

@ -179,7 +179,7 @@ func (r *IfaceToIPSet) ForceEnable() error {
// IPTables rules
err = r.PutIPTable("all")
if err != nil {
return nil
return err
}
// Mapping mark with table
@ -194,7 +194,7 @@ func (r *IfaceToIPSet) ForceEnable() error {
err = r.IfaceHandle()
if err != nil {
return nil
return err
}
r.Enabled = true

View File

@ -1,228 +0,0 @@
package main
import (
"bytes"
"net"
"sync"
"time"
)
type ARecord struct {
Address net.IP
Deadline time.Time
}
func NewARecord(addr net.IP, deadline time.Time) *ARecord {
return &ARecord{
Address: addr,
Deadline: deadline,
}
}
type CNameRecord struct {
Alias string
Deadline time.Time
}
func NewCNameRecord(domainName string, deadline time.Time) *CNameRecord {
return &CNameRecord{
Alias: domainName,
Deadline: deadline,
}
}
type Records struct {
mutex sync.RWMutex
ARecords map[string][]*ARecord
CNameRecords map[string]*CNameRecord
}
func (r *Records) cleanupARecords(now time.Time) {
for name, aRecords := range r.ARecords {
i := 0
for _, aRecord := range aRecords {
if now.After(aRecord.Deadline) {
continue
}
aRecords[i] = aRecord
i++
}
aRecords = aRecords[:i]
if i == 0 {
delete(r.ARecords, name)
}
}
}
func (r *Records) cleanupCNameRecords(now time.Time) {
for name, record := range r.CNameRecords {
if now.After(record.Deadline) {
delete(r.CNameRecords, name)
}
}
}
func (r *Records) getAliasedDomain(now time.Time, domainName string) string {
processedDomains := make(map[string]struct{})
for {
if _, processed := processedDomains[domainName]; processed {
// Loop detected!
return ""
} else {
processedDomains[domainName] = struct{}{}
}
cname, ok := r.CNameRecords[domainName]
if !ok {
break
}
if now.After(cname.Deadline) {
delete(r.CNameRecords, domainName)
break
}
domainName = cname.Alias
}
return domainName
}
func (r *Records) getActualARecords(now time.Time, domainName string) []*ARecord {
aRecords, ok := r.ARecords[domainName]
if !ok {
return nil
}
i := 0
for _, aRecord := range aRecords {
if now.After(aRecord.Deadline) {
continue
}
aRecords[i] = aRecord
i++
}
aRecords = aRecords[:i]
if i == 0 {
delete(r.ARecords, domainName)
return nil
}
return aRecords
}
func (r *Records) getActualCNames(now time.Time, domainName string, fromEnd bool) []string {
processedDomains := make(map[string]struct{})
cNameList := make([]string, 0)
if fromEnd {
domainName = r.getAliasedDomain(now, domainName)
cNameList = append(cNameList, domainName)
}
r.cleanupCNameRecords(now)
for {
if _, processed := processedDomains[domainName]; processed {
// Loop detected!
return nil
} else {
processedDomains[domainName] = struct{}{}
}
found := false
for aliasFrom, aliasTo := range r.CNameRecords {
if aliasTo.Alias == domainName {
cNameList = append(cNameList, aliasFrom)
domainName = aliasFrom
found = true
break
}
}
if !found {
break
}
}
return cNameList
}
func (r *Records) Cleanup() {
r.mutex.Lock()
defer r.mutex.Unlock()
now := time.Now()
r.cleanupARecords(now)
r.cleanupCNameRecords(now)
}
func (r *Records) GetCNameRecords(domainName string, fromEnd bool) []string {
r.mutex.RLock()
defer r.mutex.RUnlock()
now := time.Now()
return r.getActualCNames(now, domainName, fromEnd)
}
func (r *Records) GetARecords(domainName string) []*ARecord {
r.mutex.Lock()
defer r.mutex.Unlock()
now := time.Now()
return r.getActualARecords(now, r.getAliasedDomain(now, domainName))
}
func (r *Records) AddCNameRecord(domainName string, cName string, ttl time.Duration) {
if domainName == cName {
// Can't assing to yourself
return
}
r.mutex.Lock()
defer r.mutex.Unlock()
now := time.Now()
delete(r.ARecords, domainName)
r.CNameRecords[domainName] = NewCNameRecord(cName, now.Add(ttl))
}
func (r *Records) AddARecord(domainName string, addr net.IP, ttl time.Duration) {
r.mutex.Lock()
defer r.mutex.Unlock()
now := time.Now()
delete(r.CNameRecords, domainName)
if _, ok := r.ARecords[domainName]; !ok {
r.ARecords[domainName] = make([]*ARecord, 0)
}
for _, aRecord := range r.ARecords[domainName] {
if bytes.Compare(aRecord.Address, addr) == 0 {
aRecord.Deadline = now.Add(ttl)
return
}
}
r.ARecords[domainName] = append(r.ARecords[domainName], NewARecord(addr, now.Add(ttl)))
}
func (r *Records) ListKnownDomains() []string {
r.mutex.Lock()
defer r.mutex.Unlock()
now := time.Now()
r.cleanupARecords(now)
r.cleanupCNameRecords(now)
domains := map[string]struct{}{}
for name, _ := range r.ARecords {
domains[name] = struct{}{}
}
for name, _ := range r.CNameRecords {
domains[name] = struct{}{}
}
domainsList := make([]string, len(domains))
i := 0
for name, _ := range domains {
domainsList[i] = name
i++
}
return domainsList
}
func NewRecords() *Records {
return &Records{
ARecords: make(map[string][]*ARecord),
CNameRecords: make(map[string]*CNameRecord),
}
}

167
records/records.go Normal file
View File

@ -0,0 +1,167 @@
package records
import (
"bytes"
"net"
"sync"
"time"
)
type ARecord struct {
Address net.IP
Deadline time.Time
}
type CNameRecord struct {
Alias string
Deadline time.Time
}
type Records struct {
mux sync.RWMutex
records map[string]interface{}
}
func (r *Records) AddCNameRecord(domainName, alias string, ttl time.Duration) {
if domainName == alias {
return
}
r.mux.Lock()
r.records[domainName] = &CNameRecord{
Alias: alias,
Deadline: time.Now().Add(ttl),
}
r.mux.Unlock()
}
func (r *Records) AddARecord(domainName string, addr net.IP, ttl time.Duration) {
r.mux.Lock()
defer r.mux.Unlock()
deadline := time.Now().Add(ttl)
aRecords, _ := r.records[domainName].([]*ARecord)
for _, aRecord := range aRecords {
if bytes.Compare(aRecord.Address, addr) != 0 {
continue
}
aRecord.Deadline = deadline
return
}
r.records[domainName] = append(aRecords, &ARecord{
Address: addr,
Deadline: deadline,
})
}
func (r *Records) GetAliases(domainName string) []string {
r.mux.Lock()
defer r.mux.Unlock()
r.cleanupRecords()
domains := make(map[string]struct{})
domains[domainName] = struct{}{}
for {
var addedNew bool
for name, aRecord := range r.records {
if _, ok := domains[name]; ok {
continue
}
cname, ok := aRecord.(*CNameRecord)
if !ok {
continue
}
if _, ok = domains[cname.Alias]; !ok {
continue
}
domains[name] = struct{}{}
addedNew = true
}
if !addedNew {
break
}
}
domainList := make([]string, len(domains))
idx := 0
for name, _ := range domains {
domainList[idx] = name
idx++
}
return domainList
}
func (r *Records) GetARecords(domainName string) []*ARecord {
r.mux.Lock()
defer r.mux.Unlock()
r.cleanupRecords()
loopDetect := make(map[string]struct{})
loopDetect[domainName] = struct{}{}
for {
switch v := r.records[domainName].(type) {
case *CNameRecord:
if _, ok := loopDetect[v.Alias]; ok {
return nil
}
domainName = v.Alias
loopDetect[v.Alias] = struct{}{}
case []*ARecord:
return v
default:
return nil
}
}
}
func (r *Records) ListKnownDomains() []string {
r.mux.Lock()
defer r.mux.Unlock()
r.cleanupRecords()
domainsList := make([]string, len(r.records))
i := 0
for name, _ := range r.records {
domainsList[i] = name
i++
}
return domainsList
}
func (r *Records) cleanupRecords() {
now := time.Now()
for name, records := range r.records {
switch v := records.(type) {
case []*ARecord:
idx := 0
for _, aRecord := range v {
if now.After(aRecord.Deadline) {
continue
}
v[idx] = aRecord
idx++
}
if idx == 0 {
delete(r.records, name)
break
}
r.records[name] = v[:idx]
case *CNameRecord:
if !now.After(v.Deadline) {
continue
}
delete(r.records, name)
}
}
}
func New() *Records {
return &Records{
records: make(map[string]interface{}),
}
}

109
records/records_test.go Normal file
View File

@ -0,0 +1,109 @@
package records
import (
"bytes"
"slices"
"testing"
"time"
)
func TestLoop(t *testing.T) {
r := New()
r.AddCNameRecord("1", "2", time.Minute)
r.AddCNameRecord("2", "1", time.Minute)
if r.GetARecords("1") != nil {
t.Fatal("loop detected")
}
if r.GetARecords("2") != nil {
t.Fatal("loop detected")
}
}
func TestCName(t *testing.T) {
r := New()
r.AddARecord("example.com", []byte{1, 2, 3, 4}, time.Minute)
r.AddCNameRecord("gateway.example.com", "example.com", time.Minute)
records := r.GetARecords("gateway.example.com")
if records == nil {
t.Fatal("no records")
}
if bytes.Compare(records[0].Address, []byte{1, 2, 3, 4}) != 0 {
t.Fatal("cname mismatch")
}
}
func TestA(t *testing.T) {
r := New()
r.AddARecord("example.com", []byte{1, 2, 3, 4}, time.Minute)
records := r.GetARecords("example.com")
if records == nil {
t.Fatal("no records")
}
if bytes.Compare(records[0].Address, []byte{1, 2, 3, 4}) != 0 {
t.Fatal("cname mismatch")
}
}
func TestDeprecated(t *testing.T) {
r := New()
r.AddARecord("example.com", []byte{1, 2, 3, 4}, -time.Minute)
records := r.GetARecords("example.com")
if records != nil {
t.Fatal("deprecated records")
}
}
func TestNotExistedA(t *testing.T) {
r := New()
records := r.GetARecords("example.com")
if records != nil {
t.Fatal("not existed records")
}
}
func TestNotExistedCNameAlias(t *testing.T) {
r := New()
r.AddCNameRecord("gateway.example.com", "example.com", time.Minute)
records := r.GetARecords("gateway.example.com")
if records != nil {
t.Fatal("not existed records")
}
}
func TestReplacing(t *testing.T) {
r := New()
r.AddCNameRecord("gateway.example.com", "example.com", time.Minute)
r.AddARecord("gateway.example.com", []byte{1, 2, 3, 4}, time.Minute)
records := r.GetARecords("gateway.example.com")
if bytes.Compare(records[0].Address, []byte{1, 2, 3, 4}) != 0 {
t.Fatal("mismatch")
}
}
func TestAliases(t *testing.T) {
r := New()
r.AddARecord("1", []byte{1, 2, 3, 4}, time.Minute)
r.AddCNameRecord("2", "1", time.Minute)
r.AddCNameRecord("3", "2", time.Minute)
r.AddCNameRecord("4", "2", time.Minute)
r.AddCNameRecord("5", "1", time.Minute)
aliases := r.GetAliases("1")
if aliases == nil {
t.Fatal("no aliases")
}
if !slices.Contains(aliases, "1") {
t.Fatal("no 1")
}
if !slices.Contains(aliases, "2") {
t.Fatal("no 2")
}
if !slices.Contains(aliases, "3") {
t.Fatal("no 3")
}
if !slices.Contains(aliases, "4") {
t.Fatal("no 4")
}
if !slices.Contains(aliases, "5") {
t.Fatal("no 5")
}
}