config extension

This commit is contained in:
Vladimir Avtsenov 2025-02-11 23:05:24 +03:00
parent fdb1038ba9
commit d5b4645ec0
4 changed files with 18 additions and 24 deletions

View File

@ -23,9 +23,8 @@ type Group struct {
ipsetToLink *netfilterHelper.IPSetToLink ipsetToLink *netfilterHelper.IPSetToLink
} }
func (g *Group) AddIP(address net.IP, ttl time.Duration) error { func (g *Group) AddIP(address net.IP, ttl uint32) error {
ttlSeconds := uint32(ttl.Seconds()) return g.ipset.AddIP(address, &ttl)
return g.ipset.AddIP(address, &ttlSeconds)
} }
func (g *Group) DelIP(address net.IP) error { func (g *Group) DelIP(address net.IP) error {
@ -90,7 +89,7 @@ func (g *Group) Disable() []error {
func (g *Group) Sync(records *records.Records) error { func (g *Group) Sync(records *records.Records) error {
now := time.Now() now := time.Now()
addresses := make(map[string]time.Duration) addresses := make(map[string]uint32)
knownDomains := records.ListKnownDomains() knownDomains := records.ListKnownDomains()
for _, domain := range g.Rules { for _, domain := range g.Rules {
if !domain.IsEnabled() { if !domain.IsEnabled() {
@ -104,7 +103,7 @@ func (g *Group) Sync(records *records.Records) error {
domainAddresses := records.GetARecords(domainName) domainAddresses := records.GetARecords(domainName)
for _, address := range domainAddresses { for _, address := range domainAddresses {
ttl := now.Sub(address.Deadline) ttl := uint32(now.Sub(address.Deadline).Seconds())
if oldTTL, ok := addresses[string(address.Address)]; !ok || ttl > oldTTL { if oldTTL, ok := addresses[string(address.Address)]; !ok || ttl > oldTTL {
addresses[string(address.Address)] = ttl addresses[string(address.Address)] = ttl
} }

View File

@ -29,11 +29,12 @@ var (
) )
type Config struct { type Config struct {
MinimalTTL time.Duration AdditionalTTL uint32
ChainPrefix string ChainPrefix string
IpSetPrefix string IpSetPrefix string
LinkName string LinkName string
TargetDNSServerAddress string TargetDNSServerAddress string
TargetDNSServerPort uint16
ListenDNSPort uint16 ListenDNSPort uint16
} }
@ -316,10 +317,7 @@ func (a *App) processARecord(aRecord dns.A) {
Int("ttl", int(aRecord.Hdr.Ttl)). Int("ttl", int(aRecord.Hdr.Ttl)).
Msg("processing a record") Msg("processing a record")
ttlDuration := time.Duration(aRecord.Hdr.Ttl) * time.Second ttlDuration := aRecord.Hdr.Ttl + a.Config.AdditionalTTL
if ttlDuration < a.Config.MinimalTTL {
ttlDuration = a.Config.MinimalTTL
}
a.Records.AddARecord(aRecord.Hdr.Name[:len(aRecord.Hdr.Name)-1], aRecord.A, ttlDuration) a.Records.AddARecord(aRecord.Hdr.Name[:len(aRecord.Hdr.Name)-1], aRecord.A, ttlDuration)
@ -362,10 +360,7 @@ func (a *App) processCNameRecord(cNameRecord dns.CNAME) {
Int("ttl", int(cNameRecord.Hdr.Ttl)). Int("ttl", int(cNameRecord.Hdr.Ttl)).
Msg("processing cname record") Msg("processing cname record")
ttlDuration := time.Duration(cNameRecord.Hdr.Ttl) * time.Second ttlDuration := cNameRecord.Hdr.Ttl + a.Config.AdditionalTTL
if ttlDuration < a.Config.MinimalTTL {
ttlDuration = a.Config.MinimalTTL
}
a.Records.AddCNameRecord(cNameRecord.Hdr.Name[:len(cNameRecord.Hdr.Name)-1], cNameRecord.Target, ttlDuration) a.Records.AddCNameRecord(cNameRecord.Hdr.Name[:len(cNameRecord.Hdr.Name)-1], cNameRecord.Target, ttlDuration)
@ -384,7 +379,7 @@ func (a *App) processCNameRecord(cNameRecord dns.CNAME) {
continue continue
} }
for _, aRecord := range aRecords { for _, aRecord := range aRecords {
err := group.AddIP(aRecord.Address, now.Sub(aRecord.Deadline)) err := group.AddIP(aRecord.Address, uint32(now.Sub(aRecord.Deadline).Seconds()))
if err != nil { if err != nil {
log.Error(). log.Error().
Str("address", aRecord.Address.String()). Str("address", aRecord.Address.String()).
@ -429,7 +424,7 @@ func New(config Config) (*App, error) {
app.DNSMITM = dnsMitmProxy.New() app.DNSMITM = dnsMitmProxy.New()
app.DNSMITM.TargetDNSServerAddress = app.Config.TargetDNSServerAddress app.DNSMITM.TargetDNSServerAddress = app.Config.TargetDNSServerAddress
app.DNSMITM.TargetDNSServerPort = 53 app.DNSMITM.TargetDNSServerPort = app.Config.TargetDNSServerPort
app.DNSMITM.RequestHook = func(clientAddr net.Addr, reqMsg dns.Msg, network string) (*dns.Msg, *dns.Msg, error) { app.DNSMITM.RequestHook = func(clientAddr net.Addr, reqMsg dns.Msg, network string) (*dns.Msg, *dns.Msg, error) {
// TODO: Need to understand why it not works in proxy mode // TODO: Need to understand why it not works in proxy mode
if len(reqMsg.Question) == 1 && reqMsg.Question[0].Qtype == dns.TypePTR { if len(reqMsg.Question) == 1 && reqMsg.Question[0].Qtype == dns.TypePTR {

View File

@ -2,12 +2,11 @@ package main
import ( import (
"context" "context"
"github.com/rs/zerolog"
"os" "os"
"os/signal" "os/signal"
"syscall" "syscall"
"time"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
) )
@ -15,12 +14,13 @@ func main() {
log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr}) log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr})
app, err := New(Config{ app, err := New(Config{
MinimalTTL: time.Hour, AdditionalTTL: 216000, // 1 hour
ChainPrefix: "KVAS2_", ChainPrefix: "KVAS2_",
IpSetPrefix: "kvas2_", IpSetPrefix: "kvas2_",
LinkName: "br0", LinkName: "br0",
TargetDNSServerAddress: "127.0.0.1", TargetDNSServerAddress: "127.0.0.1",
ListenDNSPort: 7553, TargetDNSServerPort: 53,
ListenDNSPort: 3553,
}) })
if err != nil { if err != nil {
log.Fatal().Err(err).Msg("failed to initialize application") log.Fatal().Err(err).Msg("failed to initialize application")

View File

@ -22,7 +22,7 @@ type Records struct {
records map[string]interface{} records map[string]interface{}
} }
func (r *Records) AddCNameRecord(domainName, alias string, ttl time.Duration) { func (r *Records) AddCNameRecord(domainName, alias string, ttl uint32) {
if domainName == alias { if domainName == alias {
return return
} }
@ -30,16 +30,16 @@ func (r *Records) AddCNameRecord(domainName, alias string, ttl time.Duration) {
r.mux.Lock() r.mux.Lock()
r.records[domainName] = &CNameRecord{ r.records[domainName] = &CNameRecord{
Alias: alias, Alias: alias,
Deadline: time.Now().Add(ttl), Deadline: time.Now().Add(time.Duration(ttl) * time.Second),
} }
r.mux.Unlock() r.mux.Unlock()
} }
func (r *Records) AddARecord(domainName string, addr net.IP, ttl time.Duration) { func (r *Records) AddARecord(domainName string, addr net.IP, ttl uint32) {
r.mux.Lock() r.mux.Lock()
defer r.mux.Unlock() defer r.mux.Unlock()
deadline := time.Now().Add(ttl) deadline := time.Now().Add(time.Duration(ttl) * time.Second)
aRecords, _ := r.records[domainName].([]*ARecord) aRecords, _ := r.records[domainName].([]*ARecord)
for _, aRecord := range aRecords { for _, aRecord := range aRecords {