Compare commits

...

5 Commits

Author SHA1 Message Date
ff6ab7b859 refactor PortRemap 2025-02-11 15:53:58 +03:00
b537007c9a smol refactor 2025-02-11 15:53:09 +03:00
0f4820f499 fix name 2025-02-11 15:37:38 +03:00
5fd28ae005 refactor IPSet 2025-02-11 15:29:26 +03:00
066eeb0ab7 refactor CleanIPTables 2025-02-11 15:22:08 +03:00
6 changed files with 79 additions and 78 deletions

View File

@ -21,17 +21,17 @@ type Group struct {
ipsetToLink *netfilterHelper.IPSetToLink
}
func (g *Group) AddIPv4(address net.IP, ttl time.Duration) error {
func (g *Group) AddIP(address net.IP, ttl time.Duration) error {
ttlSeconds := uint32(ttl.Seconds())
return g.ipset.AddIP(address, &ttlSeconds)
}
func (g *Group) DelIPv4(address net.IP) error {
return g.ipset.Del(address)
func (g *Group) DelIP(address net.IP) error {
return g.ipset.DelIP(address)
}
func (g *Group) ListIPv4() (map[string]*uint32, error) {
return g.ipset.List()
func (g *Group) ListIP() (map[string]*uint32, error) {
return g.ipset.ListIPs()
}
func (g *Group) Enable() error {

View File

@ -204,18 +204,14 @@ func (a *App) start(ctx context.Context) (err error) {
args := strings.Split(string(buf[:n]), ":")
if len(args) == 3 && args[0] == "netfilter.d" {
log.Debug().Str("table", args[2]).Msg("netfilter.d event")
if a.dnsOverrider4.Enabled {
err := a.dnsOverrider4.PutIPTable(args[2])
err = a.dnsOverrider4.NetfilerDHook(args[2])
if err != nil {
log.Error().Err(err).Msg("error while fixing iptables after netfilter.d")
}
}
if a.dnsOverrider6.Enabled {
err = a.dnsOverrider6.PutIPTable(args[2])
err = a.dnsOverrider6.NetfilerDHook(args[2])
if err != nil {
log.Error().Err(err).Msg("error while fixing iptables after netfilter.d")
}
}
for _, group := range a.Groups {
err := group.ipsetToLink.NetfilerDHook(args[2])
if err != nil {
@ -295,7 +291,7 @@ func (a *App) AddGroup(group *models.Group) error {
Group: group,
iptables: a.NetfilterHelper4.IPTables,
ipset: ipset,
ipsetToLink: a.NetfilterHelper4.IfaceToIPSet(fmt.Sprintf("%sR_%8x", a.Config.ChainPrefix, group.ID.ID()), group.Interface, ipsetName, false),
ipsetToLink: a.NetfilterHelper4.IPSetToLink(fmt.Sprintf("%sR_%8x", a.Config.ChainPrefix, group.ID.ID()), group.Interface, ipsetName, false),
}
a.Groups[grp.ID] = grp
return a.SyncGroup(grp)
@ -326,7 +322,7 @@ func (a *App) SyncGroup(group *Group) error {
}
}
currentAddresses, err := group.ListIPv4()
currentAddresses, err := group.ListIP()
if err != nil {
return fmt.Errorf("failed to get old ipset list: %w", err)
}
@ -337,7 +333,7 @@ func (a *App) SyncGroup(group *Group) error {
continue
}
ip := net.IP(addr)
err = group.AddIPv4(ip, ttl)
err = group.AddIP(ip, ttl)
if err != nil {
log.Error().
Str("address", ip.String()).
@ -356,7 +352,7 @@ func (a *App) SyncGroup(group *Group) error {
continue
}
ip := net.IP(addr)
err = group.DelIPv4(ip)
err = group.DelIP(ip)
if err != nil {
log.Error().
Str("address", ip.String()).
@ -418,7 +414,7 @@ func (a *App) processARecord(aRecord dns.A) {
continue
}
// TODO: Check already existed
err := group.AddIPv4(aRecord.A, ttlDuration)
err := group.AddIP(aRecord.A, ttlDuration)
if err != nil {
log.Error().
Str("address", aRecord.A.String()).
@ -467,7 +463,7 @@ func (a *App) processCNameRecord(cNameRecord dns.CNAME) {
continue
}
for _, aRecord := range aRecords {
err := group.AddIPv4(aRecord.Address, now.Sub(aRecord.Deadline))
err := group.AddIP(aRecord.Address, now.Sub(aRecord.Deadline))
if err != nil {
log.Error().
Str("address", aRecord.Address.String()).
@ -561,7 +557,7 @@ func New(config Config) (*App, error) {
return nil, fmt.Errorf("netfilter helper init fail: %w", err)
}
app.NetfilterHelper4 = nh4
err = app.NetfilterHelper4.ClearIPTables(app.Config.ChainPrefix)
err = app.NetfilterHelper4.CleanIPTables(app.Config.ChainPrefix)
if err != nil {
return nil, fmt.Errorf("failed to clear iptables: %w", err)
}
@ -571,7 +567,7 @@ func New(config Config) (*App, error) {
return nil, fmt.Errorf("netfilter helper init fail: %w", err)
}
app.NetfilterHelper6 = nh6
err = app.NetfilterHelper6.ClearIPTables(app.Config.ChainPrefix)
err = app.NetfilterHelper6.CleanIPTables(app.Config.ChainPrefix)
if err != nil {
return nil, fmt.Errorf("failed to clear iptables: %w", err)
}

View File

@ -76,9 +76,8 @@ func (r *IPSetToLink) insertIPTablesRules(table string) error {
func (r *IPSetToLink) deleteIPTablesRules() []error {
var errs []error
var err error
err = r.IPTables.DeleteIfExists("mangle", "PREROUTING", "-m", "set", "--match-set", r.IPSetName, "dst", "-j", r.ChainName)
err := r.IPTables.DeleteIfExists("mangle", "PREROUTING", "-m", "set", "--match-set", r.IPSetName, "dst", "-j", r.ChainName)
if err != nil {
errs = append(errs, fmt.Errorf("failed to unlinking chain: %w", err))
}
@ -280,7 +279,7 @@ func (r *IPSetToLink) LinkUpdateHook() error {
return r.insertIPRoute()
}
func (nh *NetfilterHelper) IfaceToIPSet(name string, ifaceName, ipsetName string, softwareMode bool) *IPSetToLink {
func (nh *NetfilterHelper) IPSetToLink(name string, ifaceName, ipsetName string, softwareMode bool) *IPSetToLink {
return &IPSetToLink{
IPTables: nh.IPTables,
ChainName: name,

View File

@ -2,9 +2,10 @@ package netfilterHelper
import (
"fmt"
"github.com/vishvananda/netlink"
"net"
"os"
"github.com/vishvananda/netlink"
)
type IPSet struct {
@ -23,7 +24,7 @@ func (r *IPSet) AddIP(addr net.IP, timeout *uint32) error {
return nil
}
func (r *IPSet) Del(addr net.IP) error {
func (r *IPSet) DelIP(addr net.IP) error {
err := netlink.IpsetDel(r.SetName, &netlink.IPSetEntry{
IP: addr,
})
@ -33,7 +34,7 @@ func (r *IPSet) Del(addr net.IP) error {
return nil
}
func (r *IPSet) List() (map[string]*uint32, error) {
func (r *IPSet) ListIPs() (map[string]*uint32, error) {
list, err := netlink.IpsetList(r.SetName)
if err != nil {
return nil, err
@ -62,9 +63,8 @@ func (nh *NetfilterHelper) IPSet(name string) (*IPSet, error) {
return nil, err
}
defaultTimeout := uint32(300)
err = netlink.IpsetCreate(ipset.SetName, "hash:net", netlink.IpsetCreateOptions{
Timeout: &defaultTimeout,
Timeout: func(i uint32) *uint32 { return &i }(300),
})
if err != nil {
return nil, fmt.Errorf("failed to create ipset: %w", err)

View File

@ -5,11 +5,9 @@ import (
"strings"
)
func (nh *NetfilterHelper) ClearIPTables(chainPrefix string) error {
func (nh *NetfilterHelper) CleanIPTables(chainPrefix string) error {
jumpToChainPrefix := fmt.Sprintf("-j %s", chainPrefix)
tableList := []string{"nat", "mangle", "filter"}
for _, table := range tableList {
for _, table := range []string{"nat", "mangle", "filter"} {
chainListToDelete := make([]string, 0)
chains, err := nh.IPTables.ListChains(table)
@ -29,15 +27,8 @@ func (nh *NetfilterHelper) ClearIPTables(chainPrefix string) error {
}
for _, rule := range rules {
ruleSlice := strings.Split(rule, " ")
if len(ruleSlice) < 2 || ruleSlice[0] != "-A" || ruleSlice[1] != chain {
// TODO: Warn
continue
}
ruleSlice = ruleSlice[2:]
if strings.Contains(strings.Join(ruleSlice, " "), jumpToChainPrefix) {
err := nh.IPTables.Delete(table, chain, ruleSlice...)
if strings.Contains(rule, jumpToChainPrefix) {
err = nh.IPTables.Delete(table, chain, rule)
if err != nil {
return fmt.Errorf("rule deletion error: %w", err)
}
@ -46,7 +37,7 @@ func (nh *NetfilterHelper) ClearIPTables(chainPrefix string) error {
}
for _, chain := range chainListToDelete {
err := nh.IPTables.ClearAndDeleteChain(table, chain)
err = nh.IPTables.ClearAndDeleteChain(table, chain)
if err != nil {
return fmt.Errorf("deleting chain error: %w", err)
}

View File

@ -16,34 +16,32 @@ type PortRemap struct {
From uint16
To uint16
Enabled bool
enabled bool
}
func (r *PortRemap) PutIPTable(table string) error {
if table == "all" || table == "nat" {
err := r.IPTables.ClearChain("nat", r.ChainName)
func (r *PortRemap) insertIPTablesRules(table string) error {
if table == "" || table == "nat" {
err := r.IPTables.NewChain("nat", r.ChainName)
if err != nil {
return fmt.Errorf("failed to clear chain: %w", err)
// If not "AlreadyExists"
if eerr, eok := err.(*iptables.Error); !(eok && eerr.ExitStatus() == 1) {
return fmt.Errorf("failed to create chain: %w", err)
}
}
for _, addr := range r.Addresses {
var addrIP net.IP
iptablesProtocol := r.IPTables.Proto()
if (iptablesProtocol == iptables.ProtocolIPv4 && len(addr.IP) == net.IPv4len) || (iptablesProtocol == iptables.ProtocolIPv6 && len(addr.IP) == net.IPv6len) {
addrIP = addr.IP
}
if addrIP == nil {
if !((r.IPTables.Proto() == iptables.ProtocolIPv4 && len(addr.IP) == net.IPv4len) || (r.IPTables.Proto() == iptables.ProtocolIPv6 && len(addr.IP) == net.IPv6len)) {
continue
}
err = r.IPTables.AppendUnique("nat", r.ChainName, "-p", "udp", "-d", addrIP.String(), "--dport", strconv.Itoa(int(r.From)), "-j", "DNAT", "--to-destination", fmt.Sprintf(":%d", r.To))
for _, iptablesArgs := range [][]string{
{"-p", "tcp", "-d", addr.IP.String(), "--dport", strconv.Itoa(int(r.From)), "-j", "DNAT", "--to-destination", fmt.Sprintf(":%d", r.To)},
{"-p", "udp", "-d", addr.IP.String(), "--dport", strconv.Itoa(int(r.From)), "-j", "DNAT", "--to-destination", fmt.Sprintf(":%d", r.To)},
} {
err = r.IPTables.AppendUnique("nat", r.ChainName, iptablesArgs...)
if err != nil {
return fmt.Errorf("failed to create rule: %w", err)
return fmt.Errorf("failed to append rule: %w", err)
}
err = r.IPTables.AppendUnique("nat", r.ChainName, "-p", "tcp", "-d", addrIP.String(), "--dport", strconv.Itoa(int(r.From)), "-j", "DNAT", "--to-destination", fmt.Sprintf(":%d", r.To))
if err != nil {
return fmt.Errorf("failed to create rule: %w", err)
}
}
@ -56,17 +54,7 @@ func (r *PortRemap) PutIPTable(table string) error {
return nil
}
func (r *PortRemap) ForceEnable() error {
err := r.PutIPTable("all")
if err != nil {
return err
}
r.Enabled = true
return nil
}
func (r *PortRemap) Disable() []error {
func (r *PortRemap) deleteIPTablesRules() []error {
var errs []error
err := r.IPTables.DeleteIfExists("nat", "PREROUTING", "-j", r.ChainName)
@ -79,16 +67,30 @@ func (r *PortRemap) Disable() []error {
errs = append(errs, fmt.Errorf("failed to delete chain: %w", err))
}
r.Enabled = false
return errs
}
func (r *PortRemap) Enable() error {
if r.Enabled {
func (r *PortRemap) enable() error {
err := r.insertIPTablesRules("")
if err != nil {
return err
}
r.enabled = true
return nil
}
err := r.ForceEnable()
func (r *PortRemap) Enable() error {
if r.enabled {
return nil
}
err := r.IPTables.ClearChain("nat", r.ChainName)
if err != nil {
return fmt.Errorf("failed to clear chain: %w", err)
}
err = r.enable()
if err != nil {
r.Disable()
return err
@ -97,6 +99,19 @@ func (r *PortRemap) Enable() error {
return nil
}
func (r *PortRemap) Disable() []error {
errs := r.deleteIPTablesRules()
r.enabled = false
return errs
}
func (r *PortRemap) NetfilerDHook(table string) error {
if !r.enabled {
return nil
}
return r.insertIPTablesRules(table)
}
func (nh *NetfilterHelper) PortRemap(name string, from, to uint16, addr []netlink.Addr) *PortRemap {
return &PortRemap{
IPTables: nh.IPTables,