fixes
This commit is contained in:
parent
cf078c330c
commit
5bc0c3b2b4
@ -1,19 +1,17 @@
|
||||
package dnsMitm
|
||||
package dnsMitmProxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"net"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
type DNSMITM struct {
|
||||
ListenPort uint16
|
||||
type DNSMITMProxy struct {
|
||||
TargetDNSServerAddress string
|
||||
TargetDNSServerPort uint16
|
||||
|
||||
@ -21,7 +19,7 @@ type DNSMITM struct {
|
||||
ResponseHook func(net.Addr, dns.Msg, dns.Msg, string) (*dns.Msg, error)
|
||||
}
|
||||
|
||||
func (p DNSMITM) requestDNS(req []byte, network string) ([]byte, error) {
|
||||
func (p DNSMITMProxy) requestDNS(req []byte, network string) ([]byte, error) {
|
||||
serverConn, err := net.Dial(network, fmt.Sprintf("%s:%d", p.TargetDNSServerAddress, p.TargetDNSServerPort))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to dial DNS server: %w", err)
|
||||
@ -65,7 +63,7 @@ func (p DNSMITM) requestDNS(req []byte, network string) ([]byte, error) {
|
||||
return resp[:n], nil
|
||||
}
|
||||
|
||||
func (p DNSMITM) processReq(clientAddr net.Addr, req []byte, network string) ([]byte, error) {
|
||||
func (p DNSMITMProxy) processReq(clientAddr net.Addr, req []byte, network string) ([]byte, error) {
|
||||
var reqMsg dns.Msg
|
||||
if p.RequestHook != nil || p.ResponseHook != nil {
|
||||
err := reqMsg.Unpack(req)
|
||||
@ -97,14 +95,14 @@ func (p DNSMITM) processReq(clientAddr net.Addr, req []byte, network string) ([]
|
||||
|
||||
resp, err := p.requestDNS(req, network)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to send request")
|
||||
return nil, fmt.Errorf("failed to send request: %w", err)
|
||||
}
|
||||
|
||||
if p.ResponseHook != nil {
|
||||
var respMsg dns.Msg
|
||||
err = respMsg.Unpack(resp)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse response")
|
||||
return nil, fmt.Errorf("failed to parse response: %w", err)
|
||||
}
|
||||
|
||||
modifiedResp, err := p.ResponseHook(clientAddr, reqMsg, respMsg, network)
|
||||
@ -123,12 +121,7 @@ func (p DNSMITM) processReq(clientAddr net.Addr, req []byte, network string) ([]
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (p DNSMITM) ListenTCP(ctx context.Context) error {
|
||||
addr, err := net.ResolveTCPAddr("tcp", "[::]:"+strconv.Itoa(int(p.ListenPort)))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to resolve tcp address: %v", err)
|
||||
}
|
||||
|
||||
func (p DNSMITMProxy) ListenTCP(ctx context.Context, addr *net.TCPAddr) error {
|
||||
listener, err := net.ListenTCP("tcp", addr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to listen tcp port: %v", err)
|
||||
@ -184,12 +177,7 @@ func (p DNSMITM) ListenTCP(ctx context.Context) error {
|
||||
}
|
||||
}
|
||||
|
||||
func (p DNSMITM) ListenUDP(ctx context.Context) error {
|
||||
addr, err := net.ResolveUDPAddr("udp", "[::]:"+strconv.Itoa(int(p.ListenPort)))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to resolve udp address: %v", err)
|
||||
}
|
||||
|
||||
func (p DNSMITMProxy) ListenUDP(ctx context.Context, addr *net.UDPAddr) error {
|
||||
conn, err := net.ListenUDP("udp", addr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to listen udp port: %v", err)
|
||||
@ -226,14 +214,8 @@ func (p DNSMITM) ListenUDP(ctx context.Context) error {
|
||||
}
|
||||
}
|
||||
|
||||
func New(listenPort uint16, targetDNSServerAddress string, targetDNSServerPort ...uint16) *DNSMITM {
|
||||
dnsMitm := &DNSMITM{
|
||||
ListenPort: listenPort,
|
||||
TargetDNSServerAddress: targetDNSServerAddress,
|
||||
TargetDNSServerPort: 53,
|
||||
func New() *DNSMITMProxy {
|
||||
return &DNSMITMProxy{
|
||||
TargetDNSServerPort: 53,
|
||||
}
|
||||
if len(targetDNSServerPort) > 0 {
|
||||
dnsMitm.TargetDNSServerPort = targetDNSServerPort[0]
|
||||
}
|
||||
return dnsMitm
|
||||
}
|
145
kvas2.go
145
kvas2.go
@ -6,19 +6,20 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"kvas2-go/dns-mitm"
|
||||
"kvas2-go/dns-mitm-proxy"
|
||||
"kvas2-go/models"
|
||||
"kvas2-go/netfilter-helper"
|
||||
"kvas2-go/records"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/vishvananda/netlink"
|
||||
"github.com/vishvananda/netlink/nl"
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
var (
|
||||
@ -39,10 +40,10 @@ type Config struct {
|
||||
type App struct {
|
||||
Config Config
|
||||
|
||||
DNSMITM *dnsMitm.DNSMITM
|
||||
DNSMITM *dnsMitmProxy.DNSMITMProxy
|
||||
NetfilterHelper4 *netfilterHelper.NetfilterHelper
|
||||
NetfilterHelper6 *netfilterHelper.NetfilterHelper
|
||||
Records *Records
|
||||
Records *records.Records
|
||||
Groups map[uuid.UUID]*Group
|
||||
|
||||
Link netlink.Link
|
||||
@ -54,7 +55,7 @@ type App struct {
|
||||
|
||||
func (a *App) handleLink(event netlink.LinkUpdate) {
|
||||
switch event.Change {
|
||||
case unix.IFF_UP:
|
||||
case 0x00000001:
|
||||
log.Debug().
|
||||
Str("interface", event.Link.Attrs().Name).
|
||||
Str("operstatestr", event.Attrs().OperState.String()).
|
||||
@ -94,7 +95,6 @@ func (a *App) start(ctx context.Context) (err error) {
|
||||
newCtx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
// TODO: Chan err
|
||||
errChan := make(chan error)
|
||||
|
||||
/*
|
||||
@ -102,16 +102,28 @@ func (a *App) start(ctx context.Context) (err error) {
|
||||
*/
|
||||
|
||||
go func() {
|
||||
err := a.DNSMITM.ListenUDP(newCtx)
|
||||
addr, err := net.ResolveUDPAddr("udp", "[::]:"+strconv.Itoa(int(a.Config.ListenDNSPort)))
|
||||
if err != nil {
|
||||
errChan <- fmt.Errorf("failed to resolve udp address: %v", err)
|
||||
return
|
||||
}
|
||||
err = a.DNSMITM.ListenUDP(newCtx, addr)
|
||||
if err != nil {
|
||||
errChan <- fmt.Errorf("failed to serve DNS UDP proxy: %v", err)
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
err := a.DNSMITM.ListenTCP(newCtx)
|
||||
addr, err := net.ResolveTCPAddr("tcp", "[::]:"+strconv.Itoa(int(a.Config.ListenDNSPort)))
|
||||
if err != nil {
|
||||
errChan <- fmt.Errorf("failed to resolve tcp address: %v", err)
|
||||
return
|
||||
}
|
||||
err = a.DNSMITM.ListenTCP(newCtx, addr)
|
||||
if err != nil {
|
||||
errChan <- fmt.Errorf("failed to serve DNS TCP proxy: %v", err)
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
@ -125,20 +137,14 @@ func (a *App) start(ctx context.Context) (err error) {
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to override DNS (IPv4): %v", err)
|
||||
}
|
||||
defer func() {
|
||||
// TODO: Handle error
|
||||
_ = a.dnsOverrider4.Disable()
|
||||
}()
|
||||
defer func() { _ = a.dnsOverrider4.Disable() }()
|
||||
|
||||
a.dnsOverrider6 = a.NetfilterHelper6.PortRemap(fmt.Sprintf("%sDNSOR", a.Config.ChainPrefix), 53, a.Config.ListenDNSPort, addrList)
|
||||
err = a.dnsOverrider6.Enable()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to override DNS (IPv6): %v", err)
|
||||
}
|
||||
defer func() {
|
||||
// TODO: Handle error
|
||||
_ = a.dnsOverrider6.Disable()
|
||||
}()
|
||||
defer func() { _ = a.dnsOverrider6.Disable() }()
|
||||
|
||||
/*
|
||||
Groups
|
||||
@ -152,7 +158,6 @@ func (a *App) start(ctx context.Context) (err error) {
|
||||
}
|
||||
defer func() {
|
||||
for _, group := range a.Groups {
|
||||
// TODO: Handle error
|
||||
_ = group.Disable()
|
||||
}
|
||||
}()
|
||||
@ -170,13 +175,16 @@ func (a *App) start(ctx context.Context) (err error) {
|
||||
return fmt.Errorf("error while serve UNIX socket: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
// TODO: Handle error
|
||||
_ = socket.Close()
|
||||
_ = os.Remove(socketPath)
|
||||
}()
|
||||
|
||||
go func() {
|
||||
for {
|
||||
if newCtx.Err() != nil {
|
||||
return
|
||||
}
|
||||
|
||||
conn, err := socket.Accept()
|
||||
if err != nil {
|
||||
if !strings.Contains(err.Error(), "use of closed network connection") {
|
||||
@ -186,10 +194,7 @@ func (a *App) start(ctx context.Context) (err error) {
|
||||
}
|
||||
|
||||
go func(conn net.Conn) {
|
||||
defer func() {
|
||||
// TODO: Handle error
|
||||
_ = conn.Close()
|
||||
}()
|
||||
defer func() { _ = conn.Close() }()
|
||||
|
||||
buf := make([]byte, 1024)
|
||||
n, err := conn.Read(buf)
|
||||
@ -206,6 +211,12 @@ func (a *App) start(ctx context.Context) (err error) {
|
||||
log.Error().Err(err).Msg("error while fixing iptables after netfilter.d")
|
||||
}
|
||||
}
|
||||
if a.dnsOverrider6.Enabled {
|
||||
err = a.dnsOverrider6.PutIPTable(args[2])
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("error while fixing iptables after netfilter.d")
|
||||
}
|
||||
}
|
||||
for _, group := range a.Groups {
|
||||
if group.ifaceToIPSetNAT.Enabled {
|
||||
err := group.ifaceToIPSetNAT.PutIPTable(args[2])
|
||||
@ -240,7 +251,6 @@ func (a *App) start(ctx context.Context) (err error) {
|
||||
case event := <-linkUpdateChannel:
|
||||
a.handleLink(event)
|
||||
case err := <-errChan:
|
||||
close(errChan)
|
||||
return err
|
||||
case <-ctx.Done():
|
||||
return nil
|
||||
@ -288,22 +298,17 @@ func (a *App) AddGroup(group *models.Group) error {
|
||||
Group: group,
|
||||
iptables: a.NetfilterHelper4.IPTables,
|
||||
ipset: ipset,
|
||||
ifaceToIPSetNAT: a.NetfilterHelper4.IfaceToIPSet(fmt.Sprintf("%sR_%d", a.Config.ChainPrefix, group.ID), group.Interface, ipsetName, false),
|
||||
ifaceToIPSetNAT: a.NetfilterHelper4.IfaceToIPSet(fmt.Sprintf("%sR_%8x", a.Config.ChainPrefix, group.ID.ID()), group.Interface, ipsetName, false),
|
||||
}
|
||||
grp.ifaceToIPSetNAT.SoftwareMode = a.Config.UseSoftwareRouting
|
||||
a.Groups[grp.ID] = grp
|
||||
return a.SyncGroup(grp)
|
||||
}
|
||||
|
||||
func (a *App) SyncGroup(group *Group) error {
|
||||
processedDomains := make(map[string]struct{})
|
||||
newIpsetAddressesMap := make(map[string]time.Duration)
|
||||
now := time.Now()
|
||||
|
||||
oldIpsetAddresses, err := group.ListIPv4()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get old ipset list: %w", err)
|
||||
}
|
||||
|
||||
addresses := make(map[string]time.Duration)
|
||||
knownDomains := a.Records.ListKnownDomains()
|
||||
for _, domain := range group.Rules {
|
||||
if !domain.IsEnabled() {
|
||||
@ -315,26 +320,24 @@ func (a *App) SyncGroup(group *Group) error {
|
||||
continue
|
||||
}
|
||||
|
||||
cnames := a.Records.GetCNameRecords(domainName, true)
|
||||
if len(cnames) == 0 {
|
||||
continue
|
||||
}
|
||||
for _, cname := range cnames {
|
||||
processedDomains[cname] = struct{}{}
|
||||
}
|
||||
|
||||
addresses := a.Records.GetARecords(domainName)
|
||||
for _, address := range addresses {
|
||||
domainAddresses := a.Records.GetARecords(domainName)
|
||||
for _, address := range domainAddresses {
|
||||
ttl := now.Sub(address.Deadline)
|
||||
if oldTTL, ok := newIpsetAddressesMap[string(address.Address)]; !ok || ttl > oldTTL {
|
||||
newIpsetAddressesMap[string(address.Address)] = ttl
|
||||
if oldTTL, ok := addresses[string(address.Address)]; !ok || ttl > oldTTL {
|
||||
addresses[string(address.Address)] = ttl
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for addr, ttl := range newIpsetAddressesMap {
|
||||
if _, exists := oldIpsetAddresses[addr]; exists {
|
||||
currentAddresses, err := group.ListIPv4()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get old ipset list: %w", err)
|
||||
}
|
||||
|
||||
for addr, ttl := range addresses {
|
||||
// TODO: Check TTL
|
||||
if _, exists := currentAddresses[addr]; exists {
|
||||
continue
|
||||
}
|
||||
ip := net.IP(addr)
|
||||
@ -344,11 +347,16 @@ func (a *App) SyncGroup(group *Group) error {
|
||||
Str("address", ip.String()).
|
||||
Err(err).
|
||||
Msg("failed to add address")
|
||||
} else {
|
||||
log.Trace().
|
||||
Str("address", ip.String()).
|
||||
Err(err).
|
||||
Msg("add address")
|
||||
}
|
||||
}
|
||||
|
||||
for addr := range oldIpsetAddresses {
|
||||
if _, exists := newIpsetAddressesMap[addr]; exists {
|
||||
for addr := range currentAddresses {
|
||||
if _, ok := addresses[addr]; ok {
|
||||
continue
|
||||
}
|
||||
ip := net.IP(addr)
|
||||
@ -362,7 +370,7 @@ func (a *App) SyncGroup(group *Group) error {
|
||||
log.Trace().
|
||||
Str("address", ip.String()).
|
||||
Err(err).
|
||||
Msg("add address")
|
||||
Msg("del address")
|
||||
}
|
||||
}
|
||||
|
||||
@ -400,9 +408,9 @@ func (a *App) processARecord(aRecord dns.A) {
|
||||
ttlDuration = a.Config.MinimalTTL
|
||||
}
|
||||
|
||||
a.Records.AddARecord(aRecord.Hdr.Name, aRecord.A, ttlDuration)
|
||||
a.Records.AddARecord(aRecord.Hdr.Name[:len(aRecord.Hdr.Name)-1], aRecord.A, ttlDuration)
|
||||
|
||||
names := a.Records.GetCNameRecords(aRecord.Hdr.Name, true)
|
||||
names := a.Records.GetAliases(aRecord.Hdr.Name[:len(aRecord.Hdr.Name)-1])
|
||||
for _, group := range a.Groups {
|
||||
Rule:
|
||||
for _, domain := range group.Rules {
|
||||
@ -413,6 +421,7 @@ func (a *App) processARecord(aRecord dns.A) {
|
||||
if !domain.IsMatch(name) {
|
||||
continue
|
||||
}
|
||||
// TODO: Check already existed
|
||||
err := group.AddIPv4(aRecord.A, ttlDuration)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
@ -445,12 +454,12 @@ func (a *App) processCNameRecord(cNameRecord dns.CNAME) {
|
||||
ttlDuration = a.Config.MinimalTTL
|
||||
}
|
||||
|
||||
a.Records.AddCNameRecord(cNameRecord.Hdr.Name, cNameRecord.Target, ttlDuration)
|
||||
a.Records.AddCNameRecord(cNameRecord.Hdr.Name[:len(cNameRecord.Hdr.Name)-1], cNameRecord.Target, ttlDuration)
|
||||
|
||||
// TODO: Optimization
|
||||
now := time.Now()
|
||||
aRecords := a.Records.GetARecords(cNameRecord.Hdr.Name)
|
||||
names := a.Records.GetCNameRecords(cNameRecord.Hdr.Name, true)
|
||||
aRecords := a.Records.GetARecords(cNameRecord.Hdr.Name[:len(cNameRecord.Hdr.Name)-1])
|
||||
names := a.Records.GetAliases(cNameRecord.Hdr.Name[:len(cNameRecord.Hdr.Name)-1])
|
||||
for _, group := range a.Groups {
|
||||
Rule:
|
||||
for _, domain := range group.Rules {
|
||||
@ -496,12 +505,6 @@ func (a *App) handleMessage(msg dns.Msg) {
|
||||
for _, rr := range msg.Answer {
|
||||
a.handleRecord(rr)
|
||||
}
|
||||
for _, rr := range msg.Ns {
|
||||
a.handleRecord(rr)
|
||||
}
|
||||
for _, rr := range msg.Extra {
|
||||
a.handleRecord(rr)
|
||||
}
|
||||
}
|
||||
|
||||
func New(config Config) (*App, error) {
|
||||
@ -511,14 +514,10 @@ func New(config Config) (*App, error) {
|
||||
|
||||
app.Config = config
|
||||
|
||||
app.DNSMITM = dnsMitm.New(app.Config.ListenDNSPort, app.Config.TargetDNSServerAddress)
|
||||
app.DNSMITM = dnsMitmProxy.New()
|
||||
app.DNSMITM.TargetDNSServerAddress = app.Config.TargetDNSServerAddress
|
||||
app.DNSMITM.TargetDNSServerPort = 53
|
||||
app.DNSMITM.RequestHook = func(clientAddr net.Addr, reqMsg dns.Msg, network string) (*dns.Msg, *dns.Msg, error) {
|
||||
log.Debug().
|
||||
Str("network", network).
|
||||
Str("clientAddr", clientAddr.String()).
|
||||
Str("name", reqMsg.Question[0].Name).
|
||||
Msg("received DNS request")
|
||||
|
||||
// TODO: Need to understand why it not works in proxy mode
|
||||
if len(reqMsg.Question) == 1 && reqMsg.Question[0].Qtype == dns.TypePTR {
|
||||
respMsg := &dns.Msg{
|
||||
@ -530,10 +529,6 @@ func New(config Config) (*App, error) {
|
||||
},
|
||||
Question: reqMsg.Question,
|
||||
}
|
||||
log.Debug().
|
||||
Str("network", network).
|
||||
Str("clientAddr", clientAddr.String()).
|
||||
Msg("sending DNS response")
|
||||
return nil, respMsg, nil
|
||||
}
|
||||
|
||||
@ -551,20 +546,12 @@ func New(config Config) (*App, error) {
|
||||
}
|
||||
respMsg.Answer = respMsg.Answer[:idx]
|
||||
|
||||
if len(respMsg.Answer) != 0 {
|
||||
log.Debug().
|
||||
Str("network", network).
|
||||
Str("clientAddr", clientAddr.String()).
|
||||
Str("respMsg", respMsg.Answer[0].Header().Name).
|
||||
Msg("sending DNS response")
|
||||
}
|
||||
|
||||
app.handleMessage(respMsg)
|
||||
|
||||
return &respMsg, nil
|
||||
}
|
||||
|
||||
app.Records = NewRecords()
|
||||
app.Records = records.New()
|
||||
app.Groups = make(map[uuid.UUID]*Group, 0)
|
||||
|
||||
link, err := netlink.LinkByName(app.Config.LinkName)
|
||||
|
@ -179,7 +179,7 @@ func (r *IfaceToIPSet) ForceEnable() error {
|
||||
// IPTables rules
|
||||
err = r.PutIPTable("all")
|
||||
if err != nil {
|
||||
return nil
|
||||
return err
|
||||
}
|
||||
|
||||
// Mapping mark with table
|
||||
@ -194,7 +194,7 @@ func (r *IfaceToIPSet) ForceEnable() error {
|
||||
|
||||
err = r.IfaceHandle()
|
||||
if err != nil {
|
||||
return nil
|
||||
return err
|
||||
}
|
||||
|
||||
r.Enabled = true
|
||||
|
228
records.go
228
records.go
@ -1,228 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type ARecord struct {
|
||||
Address net.IP
|
||||
Deadline time.Time
|
||||
}
|
||||
|
||||
func NewARecord(addr net.IP, deadline time.Time) *ARecord {
|
||||
return &ARecord{
|
||||
Address: addr,
|
||||
Deadline: deadline,
|
||||
}
|
||||
}
|
||||
|
||||
type CNameRecord struct {
|
||||
Alias string
|
||||
Deadline time.Time
|
||||
}
|
||||
|
||||
func NewCNameRecord(domainName string, deadline time.Time) *CNameRecord {
|
||||
return &CNameRecord{
|
||||
Alias: domainName,
|
||||
Deadline: deadline,
|
||||
}
|
||||
}
|
||||
|
||||
type Records struct {
|
||||
mutex sync.RWMutex
|
||||
ARecords map[string][]*ARecord
|
||||
CNameRecords map[string]*CNameRecord
|
||||
}
|
||||
|
||||
func (r *Records) cleanupARecords(now time.Time) {
|
||||
for name, aRecords := range r.ARecords {
|
||||
i := 0
|
||||
for _, aRecord := range aRecords {
|
||||
if now.After(aRecord.Deadline) {
|
||||
continue
|
||||
}
|
||||
aRecords[i] = aRecord
|
||||
i++
|
||||
}
|
||||
aRecords = aRecords[:i]
|
||||
if i == 0 {
|
||||
delete(r.ARecords, name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Records) cleanupCNameRecords(now time.Time) {
|
||||
for name, record := range r.CNameRecords {
|
||||
if now.After(record.Deadline) {
|
||||
delete(r.CNameRecords, name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Records) getAliasedDomain(now time.Time, domainName string) string {
|
||||
processedDomains := make(map[string]struct{})
|
||||
for {
|
||||
if _, processed := processedDomains[domainName]; processed {
|
||||
// Loop detected!
|
||||
return ""
|
||||
} else {
|
||||
processedDomains[domainName] = struct{}{}
|
||||
}
|
||||
|
||||
cname, ok := r.CNameRecords[domainName]
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
if now.After(cname.Deadline) {
|
||||
delete(r.CNameRecords, domainName)
|
||||
break
|
||||
}
|
||||
domainName = cname.Alias
|
||||
}
|
||||
return domainName
|
||||
}
|
||||
|
||||
func (r *Records) getActualARecords(now time.Time, domainName string) []*ARecord {
|
||||
aRecords, ok := r.ARecords[domainName]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
i := 0
|
||||
for _, aRecord := range aRecords {
|
||||
if now.After(aRecord.Deadline) {
|
||||
continue
|
||||
}
|
||||
aRecords[i] = aRecord
|
||||
i++
|
||||
}
|
||||
aRecords = aRecords[:i]
|
||||
if i == 0 {
|
||||
delete(r.ARecords, domainName)
|
||||
return nil
|
||||
}
|
||||
|
||||
return aRecords
|
||||
}
|
||||
|
||||
func (r *Records) getActualCNames(now time.Time, domainName string, fromEnd bool) []string {
|
||||
processedDomains := make(map[string]struct{})
|
||||
cNameList := make([]string, 0)
|
||||
if fromEnd {
|
||||
domainName = r.getAliasedDomain(now, domainName)
|
||||
cNameList = append(cNameList, domainName)
|
||||
}
|
||||
r.cleanupCNameRecords(now)
|
||||
for {
|
||||
if _, processed := processedDomains[domainName]; processed {
|
||||
// Loop detected!
|
||||
return nil
|
||||
} else {
|
||||
processedDomains[domainName] = struct{}{}
|
||||
}
|
||||
|
||||
found := false
|
||||
for aliasFrom, aliasTo := range r.CNameRecords {
|
||||
if aliasTo.Alias == domainName {
|
||||
cNameList = append(cNameList, aliasFrom)
|
||||
domainName = aliasFrom
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
break
|
||||
}
|
||||
}
|
||||
return cNameList
|
||||
}
|
||||
|
||||
func (r *Records) Cleanup() {
|
||||
r.mutex.Lock()
|
||||
defer r.mutex.Unlock()
|
||||
now := time.Now()
|
||||
r.cleanupARecords(now)
|
||||
r.cleanupCNameRecords(now)
|
||||
}
|
||||
|
||||
func (r *Records) GetCNameRecords(domainName string, fromEnd bool) []string {
|
||||
r.mutex.RLock()
|
||||
defer r.mutex.RUnlock()
|
||||
now := time.Now()
|
||||
|
||||
return r.getActualCNames(now, domainName, fromEnd)
|
||||
}
|
||||
|
||||
func (r *Records) GetARecords(domainName string) []*ARecord {
|
||||
r.mutex.Lock()
|
||||
defer r.mutex.Unlock()
|
||||
now := time.Now()
|
||||
|
||||
return r.getActualARecords(now, r.getAliasedDomain(now, domainName))
|
||||
}
|
||||
|
||||
func (r *Records) AddCNameRecord(domainName string, cName string, ttl time.Duration) {
|
||||
if domainName == cName {
|
||||
// Can't assing to yourself
|
||||
return
|
||||
}
|
||||
|
||||
r.mutex.Lock()
|
||||
defer r.mutex.Unlock()
|
||||
now := time.Now()
|
||||
|
||||
delete(r.ARecords, domainName)
|
||||
r.CNameRecords[domainName] = NewCNameRecord(cName, now.Add(ttl))
|
||||
}
|
||||
|
||||
func (r *Records) AddARecord(domainName string, addr net.IP, ttl time.Duration) {
|
||||
r.mutex.Lock()
|
||||
defer r.mutex.Unlock()
|
||||
now := time.Now()
|
||||
|
||||
delete(r.CNameRecords, domainName)
|
||||
if _, ok := r.ARecords[domainName]; !ok {
|
||||
r.ARecords[domainName] = make([]*ARecord, 0)
|
||||
}
|
||||
for _, aRecord := range r.ARecords[domainName] {
|
||||
if bytes.Compare(aRecord.Address, addr) == 0 {
|
||||
aRecord.Deadline = now.Add(ttl)
|
||||
return
|
||||
}
|
||||
}
|
||||
r.ARecords[domainName] = append(r.ARecords[domainName], NewARecord(addr, now.Add(ttl)))
|
||||
}
|
||||
|
||||
func (r *Records) ListKnownDomains() []string {
|
||||
r.mutex.Lock()
|
||||
defer r.mutex.Unlock()
|
||||
now := time.Now()
|
||||
r.cleanupARecords(now)
|
||||
r.cleanupCNameRecords(now)
|
||||
|
||||
domains := map[string]struct{}{}
|
||||
for name, _ := range r.ARecords {
|
||||
domains[name] = struct{}{}
|
||||
}
|
||||
for name, _ := range r.CNameRecords {
|
||||
domains[name] = struct{}{}
|
||||
}
|
||||
|
||||
domainsList := make([]string, len(domains))
|
||||
i := 0
|
||||
for name, _ := range domains {
|
||||
domainsList[i] = name
|
||||
i++
|
||||
}
|
||||
return domainsList
|
||||
}
|
||||
|
||||
func NewRecords() *Records {
|
||||
return &Records{
|
||||
ARecords: make(map[string][]*ARecord),
|
||||
CNameRecords: make(map[string]*CNameRecord),
|
||||
}
|
||||
}
|
167
records/records.go
Normal file
167
records/records.go
Normal file
@ -0,0 +1,167 @@
|
||||
package records
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type ARecord struct {
|
||||
Address net.IP
|
||||
Deadline time.Time
|
||||
}
|
||||
|
||||
type CNameRecord struct {
|
||||
Alias string
|
||||
Deadline time.Time
|
||||
}
|
||||
|
||||
type Records struct {
|
||||
mux sync.RWMutex
|
||||
records map[string]interface{}
|
||||
}
|
||||
|
||||
func (r *Records) AddCNameRecord(domainName, alias string, ttl time.Duration) {
|
||||
if domainName == alias {
|
||||
return
|
||||
}
|
||||
|
||||
r.mux.Lock()
|
||||
r.records[domainName] = &CNameRecord{
|
||||
Alias: alias,
|
||||
Deadline: time.Now().Add(ttl),
|
||||
}
|
||||
r.mux.Unlock()
|
||||
}
|
||||
|
||||
func (r *Records) AddARecord(domainName string, addr net.IP, ttl time.Duration) {
|
||||
r.mux.Lock()
|
||||
defer r.mux.Unlock()
|
||||
|
||||
deadline := time.Now().Add(ttl)
|
||||
|
||||
aRecords, _ := r.records[domainName].([]*ARecord)
|
||||
for _, aRecord := range aRecords {
|
||||
if bytes.Compare(aRecord.Address, addr) != 0 {
|
||||
continue
|
||||
}
|
||||
aRecord.Deadline = deadline
|
||||
return
|
||||
}
|
||||
|
||||
r.records[domainName] = append(aRecords, &ARecord{
|
||||
Address: addr,
|
||||
Deadline: deadline,
|
||||
})
|
||||
}
|
||||
|
||||
func (r *Records) GetAliases(domainName string) []string {
|
||||
r.mux.Lock()
|
||||
defer r.mux.Unlock()
|
||||
r.cleanupRecords()
|
||||
|
||||
domains := make(map[string]struct{})
|
||||
domains[domainName] = struct{}{}
|
||||
|
||||
for {
|
||||
var addedNew bool
|
||||
for name, aRecord := range r.records {
|
||||
if _, ok := domains[name]; ok {
|
||||
continue
|
||||
}
|
||||
cname, ok := aRecord.(*CNameRecord)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if _, ok = domains[cname.Alias]; !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
domains[name] = struct{}{}
|
||||
addedNew = true
|
||||
}
|
||||
if !addedNew {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
domainList := make([]string, len(domains))
|
||||
idx := 0
|
||||
for name, _ := range domains {
|
||||
domainList[idx] = name
|
||||
idx++
|
||||
}
|
||||
|
||||
return domainList
|
||||
}
|
||||
|
||||
func (r *Records) GetARecords(domainName string) []*ARecord {
|
||||
r.mux.Lock()
|
||||
defer r.mux.Unlock()
|
||||
r.cleanupRecords()
|
||||
|
||||
loopDetect := make(map[string]struct{})
|
||||
loopDetect[domainName] = struct{}{}
|
||||
for {
|
||||
switch v := r.records[domainName].(type) {
|
||||
case *CNameRecord:
|
||||
if _, ok := loopDetect[v.Alias]; ok {
|
||||
return nil
|
||||
}
|
||||
domainName = v.Alias
|
||||
loopDetect[v.Alias] = struct{}{}
|
||||
case []*ARecord:
|
||||
return v
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Records) ListKnownDomains() []string {
|
||||
r.mux.Lock()
|
||||
defer r.mux.Unlock()
|
||||
r.cleanupRecords()
|
||||
|
||||
domainsList := make([]string, len(r.records))
|
||||
i := 0
|
||||
for name, _ := range r.records {
|
||||
domainsList[i] = name
|
||||
i++
|
||||
}
|
||||
return domainsList
|
||||
}
|
||||
|
||||
func (r *Records) cleanupRecords() {
|
||||
now := time.Now()
|
||||
for name, records := range r.records {
|
||||
switch v := records.(type) {
|
||||
case []*ARecord:
|
||||
idx := 0
|
||||
for _, aRecord := range v {
|
||||
if now.After(aRecord.Deadline) {
|
||||
continue
|
||||
}
|
||||
v[idx] = aRecord
|
||||
idx++
|
||||
}
|
||||
if idx == 0 {
|
||||
delete(r.records, name)
|
||||
break
|
||||
}
|
||||
r.records[name] = v[:idx]
|
||||
case *CNameRecord:
|
||||
if !now.After(v.Deadline) {
|
||||
continue
|
||||
}
|
||||
delete(r.records, name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func New() *Records {
|
||||
return &Records{
|
||||
records: make(map[string]interface{}),
|
||||
}
|
||||
}
|
109
records/records_test.go
Normal file
109
records/records_test.go
Normal file
@ -0,0 +1,109 @@
|
||||
package records
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"slices"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestLoop(t *testing.T) {
|
||||
r := New()
|
||||
r.AddCNameRecord("1", "2", time.Minute)
|
||||
r.AddCNameRecord("2", "1", time.Minute)
|
||||
if r.GetARecords("1") != nil {
|
||||
t.Fatal("loop detected")
|
||||
}
|
||||
if r.GetARecords("2") != nil {
|
||||
t.Fatal("loop detected")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCName(t *testing.T) {
|
||||
r := New()
|
||||
r.AddARecord("example.com", []byte{1, 2, 3, 4}, time.Minute)
|
||||
r.AddCNameRecord("gateway.example.com", "example.com", time.Minute)
|
||||
records := r.GetARecords("gateway.example.com")
|
||||
if records == nil {
|
||||
t.Fatal("no records")
|
||||
}
|
||||
if bytes.Compare(records[0].Address, []byte{1, 2, 3, 4}) != 0 {
|
||||
t.Fatal("cname mismatch")
|
||||
}
|
||||
}
|
||||
|
||||
func TestA(t *testing.T) {
|
||||
r := New()
|
||||
r.AddARecord("example.com", []byte{1, 2, 3, 4}, time.Minute)
|
||||
records := r.GetARecords("example.com")
|
||||
if records == nil {
|
||||
t.Fatal("no records")
|
||||
}
|
||||
if bytes.Compare(records[0].Address, []byte{1, 2, 3, 4}) != 0 {
|
||||
t.Fatal("cname mismatch")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeprecated(t *testing.T) {
|
||||
r := New()
|
||||
r.AddARecord("example.com", []byte{1, 2, 3, 4}, -time.Minute)
|
||||
records := r.GetARecords("example.com")
|
||||
if records != nil {
|
||||
t.Fatal("deprecated records")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNotExistedA(t *testing.T) {
|
||||
r := New()
|
||||
records := r.GetARecords("example.com")
|
||||
if records != nil {
|
||||
t.Fatal("not existed records")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNotExistedCNameAlias(t *testing.T) {
|
||||
r := New()
|
||||
r.AddCNameRecord("gateway.example.com", "example.com", time.Minute)
|
||||
records := r.GetARecords("gateway.example.com")
|
||||
if records != nil {
|
||||
t.Fatal("not existed records")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReplacing(t *testing.T) {
|
||||
r := New()
|
||||
r.AddCNameRecord("gateway.example.com", "example.com", time.Minute)
|
||||
r.AddARecord("gateway.example.com", []byte{1, 2, 3, 4}, time.Minute)
|
||||
records := r.GetARecords("gateway.example.com")
|
||||
if bytes.Compare(records[0].Address, []byte{1, 2, 3, 4}) != 0 {
|
||||
t.Fatal("mismatch")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAliases(t *testing.T) {
|
||||
r := New()
|
||||
r.AddARecord("1", []byte{1, 2, 3, 4}, time.Minute)
|
||||
r.AddCNameRecord("2", "1", time.Minute)
|
||||
r.AddCNameRecord("3", "2", time.Minute)
|
||||
r.AddCNameRecord("4", "2", time.Minute)
|
||||
r.AddCNameRecord("5", "1", time.Minute)
|
||||
aliases := r.GetAliases("1")
|
||||
if aliases == nil {
|
||||
t.Fatal("no aliases")
|
||||
}
|
||||
if !slices.Contains(aliases, "1") {
|
||||
t.Fatal("no 1")
|
||||
}
|
||||
if !slices.Contains(aliases, "2") {
|
||||
t.Fatal("no 2")
|
||||
}
|
||||
if !slices.Contains(aliases, "3") {
|
||||
t.Fatal("no 3")
|
||||
}
|
||||
if !slices.Contains(aliases, "4") {
|
||||
t.Fatal("no 4")
|
||||
}
|
||||
if !slices.Contains(aliases, "5") {
|
||||
t.Fatal("no 5")
|
||||
}
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user