Compare commits

..

No commits in common. "ff6ab7b85976f5d3b9c37fa7926d44d7a76039b1" and "f818a86a1a4c71f1b1f753a49f93d01cf22ff3fc" have entirely different histories.

6 changed files with 78 additions and 79 deletions

View File

@ -21,17 +21,17 @@ type Group struct {
ipsetToLink *netfilterHelper.IPSetToLink ipsetToLink *netfilterHelper.IPSetToLink
} }
func (g *Group) AddIP(address net.IP, ttl time.Duration) error { func (g *Group) AddIPv4(address net.IP, ttl time.Duration) error {
ttlSeconds := uint32(ttl.Seconds()) ttlSeconds := uint32(ttl.Seconds())
return g.ipset.AddIP(address, &ttlSeconds) return g.ipset.AddIP(address, &ttlSeconds)
} }
func (g *Group) DelIP(address net.IP) error { func (g *Group) DelIPv4(address net.IP) error {
return g.ipset.DelIP(address) return g.ipset.Del(address)
} }
func (g *Group) ListIP() (map[string]*uint32, error) { func (g *Group) ListIPv4() (map[string]*uint32, error) {
return g.ipset.ListIPs() return g.ipset.List()
} }
func (g *Group) Enable() error { func (g *Group) Enable() error {

View File

@ -204,13 +204,17 @@ func (a *App) start(ctx context.Context) (err error) {
args := strings.Split(string(buf[:n]), ":") args := strings.Split(string(buf[:n]), ":")
if len(args) == 3 && args[0] == "netfilter.d" { if len(args) == 3 && args[0] == "netfilter.d" {
log.Debug().Str("table", args[2]).Msg("netfilter.d event") log.Debug().Str("table", args[2]).Msg("netfilter.d event")
err = a.dnsOverrider4.NetfilerDHook(args[2]) if a.dnsOverrider4.Enabled {
if err != nil { err := a.dnsOverrider4.PutIPTable(args[2])
log.Error().Err(err).Msg("error while fixing iptables after netfilter.d") if err != nil {
log.Error().Err(err).Msg("error while fixing iptables after netfilter.d")
}
} }
err = a.dnsOverrider6.NetfilerDHook(args[2]) if a.dnsOverrider6.Enabled {
if err != nil { err = a.dnsOverrider6.PutIPTable(args[2])
log.Error().Err(err).Msg("error while fixing iptables after netfilter.d") if err != nil {
log.Error().Err(err).Msg("error while fixing iptables after netfilter.d")
}
} }
for _, group := range a.Groups { for _, group := range a.Groups {
err := group.ipsetToLink.NetfilerDHook(args[2]) err := group.ipsetToLink.NetfilerDHook(args[2])
@ -291,7 +295,7 @@ func (a *App) AddGroup(group *models.Group) error {
Group: group, Group: group,
iptables: a.NetfilterHelper4.IPTables, iptables: a.NetfilterHelper4.IPTables,
ipset: ipset, ipset: ipset,
ipsetToLink: a.NetfilterHelper4.IPSetToLink(fmt.Sprintf("%sR_%8x", a.Config.ChainPrefix, group.ID.ID()), group.Interface, ipsetName, false), ipsetToLink: a.NetfilterHelper4.IfaceToIPSet(fmt.Sprintf("%sR_%8x", a.Config.ChainPrefix, group.ID.ID()), group.Interface, ipsetName, false),
} }
a.Groups[grp.ID] = grp a.Groups[grp.ID] = grp
return a.SyncGroup(grp) return a.SyncGroup(grp)
@ -322,7 +326,7 @@ func (a *App) SyncGroup(group *Group) error {
} }
} }
currentAddresses, err := group.ListIP() currentAddresses, err := group.ListIPv4()
if err != nil { if err != nil {
return fmt.Errorf("failed to get old ipset list: %w", err) return fmt.Errorf("failed to get old ipset list: %w", err)
} }
@ -333,7 +337,7 @@ func (a *App) SyncGroup(group *Group) error {
continue continue
} }
ip := net.IP(addr) ip := net.IP(addr)
err = group.AddIP(ip, ttl) err = group.AddIPv4(ip, ttl)
if err != nil { if err != nil {
log.Error(). log.Error().
Str("address", ip.String()). Str("address", ip.String()).
@ -352,7 +356,7 @@ func (a *App) SyncGroup(group *Group) error {
continue continue
} }
ip := net.IP(addr) ip := net.IP(addr)
err = group.DelIP(ip) err = group.DelIPv4(ip)
if err != nil { if err != nil {
log.Error(). log.Error().
Str("address", ip.String()). Str("address", ip.String()).
@ -414,7 +418,7 @@ func (a *App) processARecord(aRecord dns.A) {
continue continue
} }
// TODO: Check already existed // TODO: Check already existed
err := group.AddIP(aRecord.A, ttlDuration) err := group.AddIPv4(aRecord.A, ttlDuration)
if err != nil { if err != nil {
log.Error(). log.Error().
Str("address", aRecord.A.String()). Str("address", aRecord.A.String()).
@ -463,7 +467,7 @@ func (a *App) processCNameRecord(cNameRecord dns.CNAME) {
continue continue
} }
for _, aRecord := range aRecords { for _, aRecord := range aRecords {
err := group.AddIP(aRecord.Address, now.Sub(aRecord.Deadline)) err := group.AddIPv4(aRecord.Address, now.Sub(aRecord.Deadline))
if err != nil { if err != nil {
log.Error(). log.Error().
Str("address", aRecord.Address.String()). Str("address", aRecord.Address.String()).
@ -557,7 +561,7 @@ func New(config Config) (*App, error) {
return nil, fmt.Errorf("netfilter helper init fail: %w", err) return nil, fmt.Errorf("netfilter helper init fail: %w", err)
} }
app.NetfilterHelper4 = nh4 app.NetfilterHelper4 = nh4
err = app.NetfilterHelper4.CleanIPTables(app.Config.ChainPrefix) err = app.NetfilterHelper4.ClearIPTables(app.Config.ChainPrefix)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to clear iptables: %w", err) return nil, fmt.Errorf("failed to clear iptables: %w", err)
} }
@ -567,7 +571,7 @@ func New(config Config) (*App, error) {
return nil, fmt.Errorf("netfilter helper init fail: %w", err) return nil, fmt.Errorf("netfilter helper init fail: %w", err)
} }
app.NetfilterHelper6 = nh6 app.NetfilterHelper6 = nh6
err = app.NetfilterHelper6.CleanIPTables(app.Config.ChainPrefix) err = app.NetfilterHelper6.ClearIPTables(app.Config.ChainPrefix)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to clear iptables: %w", err) return nil, fmt.Errorf("failed to clear iptables: %w", err)
} }

View File

@ -76,8 +76,9 @@ func (r *IPSetToLink) insertIPTablesRules(table string) error {
func (r *IPSetToLink) deleteIPTablesRules() []error { func (r *IPSetToLink) deleteIPTablesRules() []error {
var errs []error var errs []error
var err error
err := r.IPTables.DeleteIfExists("mangle", "PREROUTING", "-m", "set", "--match-set", r.IPSetName, "dst", "-j", r.ChainName) err = r.IPTables.DeleteIfExists("mangle", "PREROUTING", "-m", "set", "--match-set", r.IPSetName, "dst", "-j", r.ChainName)
if err != nil { if err != nil {
errs = append(errs, fmt.Errorf("failed to unlinking chain: %w", err)) errs = append(errs, fmt.Errorf("failed to unlinking chain: %w", err))
} }
@ -279,7 +280,7 @@ func (r *IPSetToLink) LinkUpdateHook() error {
return r.insertIPRoute() return r.insertIPRoute()
} }
func (nh *NetfilterHelper) IPSetToLink(name string, ifaceName, ipsetName string, softwareMode bool) *IPSetToLink { func (nh *NetfilterHelper) IfaceToIPSet(name string, ifaceName, ipsetName string, softwareMode bool) *IPSetToLink {
return &IPSetToLink{ return &IPSetToLink{
IPTables: nh.IPTables, IPTables: nh.IPTables,
ChainName: name, ChainName: name,

View File

@ -2,10 +2,9 @@ package netfilterHelper
import ( import (
"fmt" "fmt"
"github.com/vishvananda/netlink"
"net" "net"
"os" "os"
"github.com/vishvananda/netlink"
) )
type IPSet struct { type IPSet struct {
@ -24,7 +23,7 @@ func (r *IPSet) AddIP(addr net.IP, timeout *uint32) error {
return nil return nil
} }
func (r *IPSet) DelIP(addr net.IP) error { func (r *IPSet) Del(addr net.IP) error {
err := netlink.IpsetDel(r.SetName, &netlink.IPSetEntry{ err := netlink.IpsetDel(r.SetName, &netlink.IPSetEntry{
IP: addr, IP: addr,
}) })
@ -34,7 +33,7 @@ func (r *IPSet) DelIP(addr net.IP) error {
return nil return nil
} }
func (r *IPSet) ListIPs() (map[string]*uint32, error) { func (r *IPSet) List() (map[string]*uint32, error) {
list, err := netlink.IpsetList(r.SetName) list, err := netlink.IpsetList(r.SetName)
if err != nil { if err != nil {
return nil, err return nil, err
@ -63,8 +62,9 @@ func (nh *NetfilterHelper) IPSet(name string) (*IPSet, error) {
return nil, err return nil, err
} }
defaultTimeout := uint32(300)
err = netlink.IpsetCreate(ipset.SetName, "hash:net", netlink.IpsetCreateOptions{ err = netlink.IpsetCreate(ipset.SetName, "hash:net", netlink.IpsetCreateOptions{
Timeout: func(i uint32) *uint32 { return &i }(300), Timeout: &defaultTimeout,
}) })
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to create ipset: %w", err) return nil, fmt.Errorf("failed to create ipset: %w", err)

View File

@ -5,9 +5,11 @@ import (
"strings" "strings"
) )
func (nh *NetfilterHelper) CleanIPTables(chainPrefix string) error { func (nh *NetfilterHelper) ClearIPTables(chainPrefix string) error {
jumpToChainPrefix := fmt.Sprintf("-j %s", chainPrefix) jumpToChainPrefix := fmt.Sprintf("-j %s", chainPrefix)
for _, table := range []string{"nat", "mangle", "filter"} { tableList := []string{"nat", "mangle", "filter"}
for _, table := range tableList {
chainListToDelete := make([]string, 0) chainListToDelete := make([]string, 0)
chains, err := nh.IPTables.ListChains(table) chains, err := nh.IPTables.ListChains(table)
@ -27,8 +29,15 @@ func (nh *NetfilterHelper) CleanIPTables(chainPrefix string) error {
} }
for _, rule := range rules { for _, rule := range rules {
if strings.Contains(rule, jumpToChainPrefix) { ruleSlice := strings.Split(rule, " ")
err = nh.IPTables.Delete(table, chain, rule) if len(ruleSlice) < 2 || ruleSlice[0] != "-A" || ruleSlice[1] != chain {
// TODO: Warn
continue
}
ruleSlice = ruleSlice[2:]
if strings.Contains(strings.Join(ruleSlice, " "), jumpToChainPrefix) {
err := nh.IPTables.Delete(table, chain, ruleSlice...)
if err != nil { if err != nil {
return fmt.Errorf("rule deletion error: %w", err) return fmt.Errorf("rule deletion error: %w", err)
} }
@ -37,7 +46,7 @@ func (nh *NetfilterHelper) CleanIPTables(chainPrefix string) error {
} }
for _, chain := range chainListToDelete { for _, chain := range chainListToDelete {
err = nh.IPTables.ClearAndDeleteChain(table, chain) err := nh.IPTables.ClearAndDeleteChain(table, chain)
if err != nil { if err != nil {
return fmt.Errorf("deleting chain error: %w", err) return fmt.Errorf("deleting chain error: %w", err)
} }

View File

@ -16,32 +16,34 @@ type PortRemap struct {
From uint16 From uint16
To uint16 To uint16
enabled bool Enabled bool
} }
func (r *PortRemap) insertIPTablesRules(table string) error { func (r *PortRemap) PutIPTable(table string) error {
if table == "" || table == "nat" { if table == "all" || table == "nat" {
err := r.IPTables.NewChain("nat", r.ChainName) err := r.IPTables.ClearChain("nat", r.ChainName)
if err != nil { if err != nil {
// If not "AlreadyExists" return fmt.Errorf("failed to clear chain: %w", err)
if eerr, eok := err.(*iptables.Error); !(eok && eerr.ExitStatus() == 1) {
return fmt.Errorf("failed to create chain: %w", err)
}
} }
for _, addr := range r.Addresses { for _, addr := range r.Addresses {
if !((r.IPTables.Proto() == iptables.ProtocolIPv4 && len(addr.IP) == net.IPv4len) || (r.IPTables.Proto() == iptables.ProtocolIPv6 && len(addr.IP) == net.IPv6len)) { var addrIP net.IP
iptablesProtocol := r.IPTables.Proto()
if (iptablesProtocol == iptables.ProtocolIPv4 && len(addr.IP) == net.IPv4len) || (iptablesProtocol == iptables.ProtocolIPv6 && len(addr.IP) == net.IPv6len) {
addrIP = addr.IP
}
if addrIP == nil {
continue continue
} }
for _, iptablesArgs := range [][]string{ err = r.IPTables.AppendUnique("nat", r.ChainName, "-p", "udp", "-d", addrIP.String(), "--dport", strconv.Itoa(int(r.From)), "-j", "DNAT", "--to-destination", fmt.Sprintf(":%d", r.To))
{"-p", "tcp", "-d", addr.IP.String(), "--dport", strconv.Itoa(int(r.From)), "-j", "DNAT", "--to-destination", fmt.Sprintf(":%d", r.To)}, if err != nil {
{"-p", "udp", "-d", addr.IP.String(), "--dport", strconv.Itoa(int(r.From)), "-j", "DNAT", "--to-destination", fmt.Sprintf(":%d", r.To)}, return fmt.Errorf("failed to create rule: %w", err)
} { }
err = r.IPTables.AppendUnique("nat", r.ChainName, iptablesArgs...)
if err != nil { err = r.IPTables.AppendUnique("nat", r.ChainName, "-p", "tcp", "-d", addrIP.String(), "--dport", strconv.Itoa(int(r.From)), "-j", "DNAT", "--to-destination", fmt.Sprintf(":%d", r.To))
return fmt.Errorf("failed to append rule: %w", err) if err != nil {
} return fmt.Errorf("failed to create rule: %w", err)
} }
} }
@ -54,7 +56,17 @@ func (r *PortRemap) insertIPTablesRules(table string) error {
return nil return nil
} }
func (r *PortRemap) deleteIPTablesRules() []error { func (r *PortRemap) ForceEnable() error {
err := r.PutIPTable("all")
if err != nil {
return err
}
r.Enabled = true
return nil
}
func (r *PortRemap) Disable() []error {
var errs []error var errs []error
err := r.IPTables.DeleteIfExists("nat", "PREROUTING", "-j", r.ChainName) err := r.IPTables.DeleteIfExists("nat", "PREROUTING", "-j", r.ChainName)
@ -67,30 +79,16 @@ func (r *PortRemap) deleteIPTablesRules() []error {
errs = append(errs, fmt.Errorf("failed to delete chain: %w", err)) errs = append(errs, fmt.Errorf("failed to delete chain: %w", err))
} }
r.Enabled = false
return errs return errs
} }
func (r *PortRemap) enable() error {
err := r.insertIPTablesRules("")
if err != nil {
return err
}
r.enabled = true
return nil
}
func (r *PortRemap) Enable() error { func (r *PortRemap) Enable() error {
if r.enabled { if r.Enabled {
return nil return nil
} }
err := r.IPTables.ClearChain("nat", r.ChainName) err := r.ForceEnable()
if err != nil {
return fmt.Errorf("failed to clear chain: %w", err)
}
err = r.enable()
if err != nil { if err != nil {
r.Disable() r.Disable()
return err return err
@ -99,19 +97,6 @@ func (r *PortRemap) Enable() error {
return nil return nil
} }
func (r *PortRemap) Disable() []error {
errs := r.deleteIPTablesRules()
r.enabled = false
return errs
}
func (r *PortRemap) NetfilerDHook(table string) error {
if !r.enabled {
return nil
}
return r.insertIPTablesRules(table)
}
func (nh *NetfilterHelper) PortRemap(name string, from, to uint16, addr []netlink.Addr) *PortRemap { func (nh *NetfilterHelper) PortRemap(name string, from, to uint16, addr []netlink.Addr) *PortRemap {
return &PortRemap{ return &PortRemap{
IPTables: nh.IPTables, IPTables: nh.IPTables,