From 912f56246ac037e62b79577523ef481dd14d7c5d Mon Sep 17 00:00:00 2001 From: Vladimir Avtsenov Date: Tue, 22 Oct 2024 23:23:07 +0300 Subject: [PATCH] fix foreign dns blocking --- kvas2.go | 19 +++++++++++++++++-- main.go | 1 + netfilter-helper/port-remap.go | 27 +++++++++++++++++++++------ 3 files changed, 39 insertions(+), 8 deletions(-) diff --git a/kvas2.go b/kvas2.go index 4c5387a..490f010 100644 --- a/kvas2.go +++ b/kvas2.go @@ -15,6 +15,7 @@ import ( "github.com/rs/zerolog/log" "github.com/vishvananda/netlink" + "github.com/vishvananda/netlink/nl" ) var ( @@ -26,6 +27,7 @@ type Config struct { MinimalTTL time.Duration ChainPrefix string IpSetPrefix string + LinkName string TargetDNSServerAddress string ListenPort uint16 UseSoftwareRouting bool @@ -40,6 +42,8 @@ type App struct { Records *Records Groups map[int]*Group + Link netlink.Link + isRunning bool dnsOverrider4 *netfilterHelper.PortRemap dnsOverrider6 *netfilterHelper.PortRemap @@ -92,7 +96,12 @@ func (a *App) listen(ctx context.Context) (err error) { } }() - a.dnsOverrider4 = a.NetfilterHelper4.PortRemap(fmt.Sprintf("%sDNSOR", a.Config.ChainPrefix), 53, a.Config.ListenPort) + addrList, err := netlink.AddrList(a.Link, nl.FAMILY_ALL) + if err != nil { + return fmt.Errorf("failed to addrList address: %w", err) + } + + a.dnsOverrider4 = a.NetfilterHelper4.PortRemap(fmt.Sprintf("%sDNSOR", a.Config.ChainPrefix), 53, a.Config.ListenPort, addrList) err = a.dnsOverrider4.Enable() if err != nil { return fmt.Errorf("failed to override DNS (IPv4): %v", err) @@ -102,7 +111,7 @@ func (a *App) listen(ctx context.Context) (err error) { _ = a.dnsOverrider4.Disable() }() - a.dnsOverrider6 = a.NetfilterHelper6.PortRemap(fmt.Sprintf("%sDNSOR", a.Config.ChainPrefix), 53, a.Config.ListenPort) + a.dnsOverrider6 = a.NetfilterHelper6.PortRemap(fmt.Sprintf("%sDNSOR", a.Config.ChainPrefix), 53, a.Config.ListenPort, addrList) err = a.dnsOverrider6.Enable() if err != nil { return fmt.Errorf("failed to override DNS (IPv6): %v", err) @@ -474,6 +483,12 @@ func New(config Config) (*App, error) { app.Config = config + link, err := netlink.LinkByName(app.Config.LinkName) + if err != nil { + return nil, fmt.Errorf("failed to find link %s: %w", app.Config.LinkName, err) + } + app.Link = link + app.DNSProxy = dnsProxy.New(app.Config.ListenPort, app.Config.TargetDNSServerAddress) app.DNSProxy.MsgHandler = app.handleMessage diff --git a/main.go b/main.go index 1a83e5d..5242728 100644 --- a/main.go +++ b/main.go @@ -18,6 +18,7 @@ func main() { MinimalTTL: time.Hour, ChainPrefix: "KVAS2_", IpSetPrefix: "kvas2_", + LinkName: "br0", TargetDNSServerAddress: "127.0.0.1:53", ListenPort: 7548, }) diff --git a/netfilter-helper/port-remap.go b/netfilter-helper/port-remap.go index d81d6fb..0e3725d 100644 --- a/netfilter-helper/port-remap.go +++ b/netfilter-helper/port-remap.go @@ -2,13 +2,17 @@ package netfilterHelper import ( "fmt" - "github.com/coreos/go-iptables/iptables" + "net" "strconv" + + "github.com/coreos/go-iptables/iptables" + "github.com/vishvananda/netlink" ) type PortRemap struct { IPTables *iptables.IPTables ChainName string + Addresses []netlink.Addr From uint16 To uint16 @@ -22,10 +26,20 @@ func (r *PortRemap) PutIPTable(table string) error { return fmt.Errorf("failed to clear chain: %w", err) } - // TODO: Add `-d ` - err = r.IPTables.AppendUnique("nat", r.ChainName, "-p", "udp", "--dport", strconv.Itoa(int(r.From)), "-j", "DNAT", "--to-destination", fmt.Sprintf(":%d", r.To)) - if err != nil { - return fmt.Errorf("failed to create rule: %w", err) + for _, addr := range r.Addresses { + var addrIP net.IP + iptablesProtocol := r.IPTables.Proto() + if (iptablesProtocol == iptables.ProtocolIPv4 && len(addr.IP) == net.IPv4len) || (iptablesProtocol == iptables.ProtocolIPv6 && len(addr.IP) == net.IPv6len) { + addrIP = addr.IP + } + if addrIP == nil { + continue + } + + err = r.IPTables.AppendUnique("nat", r.ChainName, "-p", "udp", "-d", addrIP.String(), "--dport", strconv.Itoa(int(r.From)), "-j", "DNAT", "--to-destination", fmt.Sprintf(":%d", r.To)) + if err != nil { + return fmt.Errorf("failed to create rule: %w", err) + } } err = r.IPTables.InsertUnique("nat", "PREROUTING", 1, "-j", r.ChainName) @@ -78,10 +92,11 @@ func (r *PortRemap) Enable() error { return nil } -func (nh *NetfilterHelper) PortRemap(name string, from, to uint16) *PortRemap { +func (nh *NetfilterHelper) PortRemap(name string, from, to uint16, addr []netlink.Addr) *PortRemap { return &PortRemap{ IPTables: nh.IPTables, ChainName: name, + Addresses: addr, From: from, To: to, }