Compare commits

..

No commits in common. "ff6ab7b85976f5d3b9c37fa7926d44d7a76039b1" and "f818a86a1a4c71f1b1f753a49f93d01cf22ff3fc" have entirely different histories.

6 changed files with 78 additions and 79 deletions

View File

@ -21,17 +21,17 @@ type Group struct {
ipsetToLink *netfilterHelper.IPSetToLink
}
func (g *Group) AddIP(address net.IP, ttl time.Duration) error {
func (g *Group) AddIPv4(address net.IP, ttl time.Duration) error {
ttlSeconds := uint32(ttl.Seconds())
return g.ipset.AddIP(address, &ttlSeconds)
}
func (g *Group) DelIP(address net.IP) error {
return g.ipset.DelIP(address)
func (g *Group) DelIPv4(address net.IP) error {
return g.ipset.Del(address)
}
func (g *Group) ListIP() (map[string]*uint32, error) {
return g.ipset.ListIPs()
func (g *Group) ListIPv4() (map[string]*uint32, error) {
return g.ipset.List()
}
func (g *Group) Enable() error {

View File

@ -204,13 +204,17 @@ func (a *App) start(ctx context.Context) (err error) {
args := strings.Split(string(buf[:n]), ":")
if len(args) == 3 && args[0] == "netfilter.d" {
log.Debug().Str("table", args[2]).Msg("netfilter.d event")
err = a.dnsOverrider4.NetfilerDHook(args[2])
if err != nil {
log.Error().Err(err).Msg("error while fixing iptables after netfilter.d")
if a.dnsOverrider4.Enabled {
err := a.dnsOverrider4.PutIPTable(args[2])
if err != nil {
log.Error().Err(err).Msg("error while fixing iptables after netfilter.d")
}
}
err = a.dnsOverrider6.NetfilerDHook(args[2])
if err != nil {
log.Error().Err(err).Msg("error while fixing iptables after netfilter.d")
if a.dnsOverrider6.Enabled {
err = a.dnsOverrider6.PutIPTable(args[2])
if err != nil {
log.Error().Err(err).Msg("error while fixing iptables after netfilter.d")
}
}
for _, group := range a.Groups {
err := group.ipsetToLink.NetfilerDHook(args[2])
@ -291,7 +295,7 @@ func (a *App) AddGroup(group *models.Group) error {
Group: group,
iptables: a.NetfilterHelper4.IPTables,
ipset: ipset,
ipsetToLink: a.NetfilterHelper4.IPSetToLink(fmt.Sprintf("%sR_%8x", a.Config.ChainPrefix, group.ID.ID()), group.Interface, ipsetName, false),
ipsetToLink: a.NetfilterHelper4.IfaceToIPSet(fmt.Sprintf("%sR_%8x", a.Config.ChainPrefix, group.ID.ID()), group.Interface, ipsetName, false),
}
a.Groups[grp.ID] = grp
return a.SyncGroup(grp)
@ -322,7 +326,7 @@ func (a *App) SyncGroup(group *Group) error {
}
}
currentAddresses, err := group.ListIP()
currentAddresses, err := group.ListIPv4()
if err != nil {
return fmt.Errorf("failed to get old ipset list: %w", err)
}
@ -333,7 +337,7 @@ func (a *App) SyncGroup(group *Group) error {
continue
}
ip := net.IP(addr)
err = group.AddIP(ip, ttl)
err = group.AddIPv4(ip, ttl)
if err != nil {
log.Error().
Str("address", ip.String()).
@ -352,7 +356,7 @@ func (a *App) SyncGroup(group *Group) error {
continue
}
ip := net.IP(addr)
err = group.DelIP(ip)
err = group.DelIPv4(ip)
if err != nil {
log.Error().
Str("address", ip.String()).
@ -414,7 +418,7 @@ func (a *App) processARecord(aRecord dns.A) {
continue
}
// TODO: Check already existed
err := group.AddIP(aRecord.A, ttlDuration)
err := group.AddIPv4(aRecord.A, ttlDuration)
if err != nil {
log.Error().
Str("address", aRecord.A.String()).
@ -463,7 +467,7 @@ func (a *App) processCNameRecord(cNameRecord dns.CNAME) {
continue
}
for _, aRecord := range aRecords {
err := group.AddIP(aRecord.Address, now.Sub(aRecord.Deadline))
err := group.AddIPv4(aRecord.Address, now.Sub(aRecord.Deadline))
if err != nil {
log.Error().
Str("address", aRecord.Address.String()).
@ -557,7 +561,7 @@ func New(config Config) (*App, error) {
return nil, fmt.Errorf("netfilter helper init fail: %w", err)
}
app.NetfilterHelper4 = nh4
err = app.NetfilterHelper4.CleanIPTables(app.Config.ChainPrefix)
err = app.NetfilterHelper4.ClearIPTables(app.Config.ChainPrefix)
if err != nil {
return nil, fmt.Errorf("failed to clear iptables: %w", err)
}
@ -567,7 +571,7 @@ func New(config Config) (*App, error) {
return nil, fmt.Errorf("netfilter helper init fail: %w", err)
}
app.NetfilterHelper6 = nh6
err = app.NetfilterHelper6.CleanIPTables(app.Config.ChainPrefix)
err = app.NetfilterHelper6.ClearIPTables(app.Config.ChainPrefix)
if err != nil {
return nil, fmt.Errorf("failed to clear iptables: %w", err)
}

View File

@ -76,8 +76,9 @@ func (r *IPSetToLink) insertIPTablesRules(table string) error {
func (r *IPSetToLink) deleteIPTablesRules() []error {
var errs []error
var err error
err := r.IPTables.DeleteIfExists("mangle", "PREROUTING", "-m", "set", "--match-set", r.IPSetName, "dst", "-j", r.ChainName)
err = r.IPTables.DeleteIfExists("mangle", "PREROUTING", "-m", "set", "--match-set", r.IPSetName, "dst", "-j", r.ChainName)
if err != nil {
errs = append(errs, fmt.Errorf("failed to unlinking chain: %w", err))
}
@ -279,7 +280,7 @@ func (r *IPSetToLink) LinkUpdateHook() error {
return r.insertIPRoute()
}
func (nh *NetfilterHelper) IPSetToLink(name string, ifaceName, ipsetName string, softwareMode bool) *IPSetToLink {
func (nh *NetfilterHelper) IfaceToIPSet(name string, ifaceName, ipsetName string, softwareMode bool) *IPSetToLink {
return &IPSetToLink{
IPTables: nh.IPTables,
ChainName: name,

View File

@ -2,10 +2,9 @@ package netfilterHelper
import (
"fmt"
"github.com/vishvananda/netlink"
"net"
"os"
"github.com/vishvananda/netlink"
)
type IPSet struct {
@ -24,7 +23,7 @@ func (r *IPSet) AddIP(addr net.IP, timeout *uint32) error {
return nil
}
func (r *IPSet) DelIP(addr net.IP) error {
func (r *IPSet) Del(addr net.IP) error {
err := netlink.IpsetDel(r.SetName, &netlink.IPSetEntry{
IP: addr,
})
@ -34,7 +33,7 @@ func (r *IPSet) DelIP(addr net.IP) error {
return nil
}
func (r *IPSet) ListIPs() (map[string]*uint32, error) {
func (r *IPSet) List() (map[string]*uint32, error) {
list, err := netlink.IpsetList(r.SetName)
if err != nil {
return nil, err
@ -63,8 +62,9 @@ func (nh *NetfilterHelper) IPSet(name string) (*IPSet, error) {
return nil, err
}
defaultTimeout := uint32(300)
err = netlink.IpsetCreate(ipset.SetName, "hash:net", netlink.IpsetCreateOptions{
Timeout: func(i uint32) *uint32 { return &i }(300),
Timeout: &defaultTimeout,
})
if err != nil {
return nil, fmt.Errorf("failed to create ipset: %w", err)

View File

@ -5,9 +5,11 @@ import (
"strings"
)
func (nh *NetfilterHelper) CleanIPTables(chainPrefix string) error {
func (nh *NetfilterHelper) ClearIPTables(chainPrefix string) error {
jumpToChainPrefix := fmt.Sprintf("-j %s", chainPrefix)
for _, table := range []string{"nat", "mangle", "filter"} {
tableList := []string{"nat", "mangle", "filter"}
for _, table := range tableList {
chainListToDelete := make([]string, 0)
chains, err := nh.IPTables.ListChains(table)
@ -27,8 +29,15 @@ func (nh *NetfilterHelper) CleanIPTables(chainPrefix string) error {
}
for _, rule := range rules {
if strings.Contains(rule, jumpToChainPrefix) {
err = nh.IPTables.Delete(table, chain, rule)
ruleSlice := strings.Split(rule, " ")
if len(ruleSlice) < 2 || ruleSlice[0] != "-A" || ruleSlice[1] != chain {
// TODO: Warn
continue
}
ruleSlice = ruleSlice[2:]
if strings.Contains(strings.Join(ruleSlice, " "), jumpToChainPrefix) {
err := nh.IPTables.Delete(table, chain, ruleSlice...)
if err != nil {
return fmt.Errorf("rule deletion error: %w", err)
}
@ -37,7 +46,7 @@ func (nh *NetfilterHelper) CleanIPTables(chainPrefix string) error {
}
for _, chain := range chainListToDelete {
err = nh.IPTables.ClearAndDeleteChain(table, chain)
err := nh.IPTables.ClearAndDeleteChain(table, chain)
if err != nil {
return fmt.Errorf("deleting chain error: %w", err)
}

View File

@ -16,32 +16,34 @@ type PortRemap struct {
From uint16
To uint16
enabled bool
Enabled bool
}
func (r *PortRemap) insertIPTablesRules(table string) error {
if table == "" || table == "nat" {
err := r.IPTables.NewChain("nat", r.ChainName)
func (r *PortRemap) PutIPTable(table string) error {
if table == "all" || table == "nat" {
err := r.IPTables.ClearChain("nat", r.ChainName)
if err != nil {
// If not "AlreadyExists"
if eerr, eok := err.(*iptables.Error); !(eok && eerr.ExitStatus() == 1) {
return fmt.Errorf("failed to create chain: %w", err)
}
return fmt.Errorf("failed to clear chain: %w", err)
}
for _, addr := range r.Addresses {
if !((r.IPTables.Proto() == iptables.ProtocolIPv4 && len(addr.IP) == net.IPv4len) || (r.IPTables.Proto() == iptables.ProtocolIPv6 && len(addr.IP) == net.IPv6len)) {
var addrIP net.IP
iptablesProtocol := r.IPTables.Proto()
if (iptablesProtocol == iptables.ProtocolIPv4 && len(addr.IP) == net.IPv4len) || (iptablesProtocol == iptables.ProtocolIPv6 && len(addr.IP) == net.IPv6len) {
addrIP = addr.IP
}
if addrIP == nil {
continue
}
for _, iptablesArgs := range [][]string{
{"-p", "tcp", "-d", addr.IP.String(), "--dport", strconv.Itoa(int(r.From)), "-j", "DNAT", "--to-destination", fmt.Sprintf(":%d", r.To)},
{"-p", "udp", "-d", addr.IP.String(), "--dport", strconv.Itoa(int(r.From)), "-j", "DNAT", "--to-destination", fmt.Sprintf(":%d", r.To)},
} {
err = r.IPTables.AppendUnique("nat", r.ChainName, iptablesArgs...)
if err != nil {
return fmt.Errorf("failed to append rule: %w", err)
}
err = r.IPTables.AppendUnique("nat", r.ChainName, "-p", "udp", "-d", addrIP.String(), "--dport", strconv.Itoa(int(r.From)), "-j", "DNAT", "--to-destination", fmt.Sprintf(":%d", r.To))
if err != nil {
return fmt.Errorf("failed to create rule: %w", err)
}
err = r.IPTables.AppendUnique("nat", r.ChainName, "-p", "tcp", "-d", addrIP.String(), "--dport", strconv.Itoa(int(r.From)), "-j", "DNAT", "--to-destination", fmt.Sprintf(":%d", r.To))
if err != nil {
return fmt.Errorf("failed to create rule: %w", err)
}
}
@ -54,7 +56,17 @@ func (r *PortRemap) insertIPTablesRules(table string) error {
return nil
}
func (r *PortRemap) deleteIPTablesRules() []error {
func (r *PortRemap) ForceEnable() error {
err := r.PutIPTable("all")
if err != nil {
return err
}
r.Enabled = true
return nil
}
func (r *PortRemap) Disable() []error {
var errs []error
err := r.IPTables.DeleteIfExists("nat", "PREROUTING", "-j", r.ChainName)
@ -67,30 +79,16 @@ func (r *PortRemap) deleteIPTablesRules() []error {
errs = append(errs, fmt.Errorf("failed to delete chain: %w", err))
}
r.Enabled = false
return errs
}
func (r *PortRemap) enable() error {
err := r.insertIPTablesRules("")
if err != nil {
return err
}
r.enabled = true
return nil
}
func (r *PortRemap) Enable() error {
if r.enabled {
if r.Enabled {
return nil
}
err := r.IPTables.ClearChain("nat", r.ChainName)
if err != nil {
return fmt.Errorf("failed to clear chain: %w", err)
}
err = r.enable()
err := r.ForceEnable()
if err != nil {
r.Disable()
return err
@ -99,19 +97,6 @@ func (r *PortRemap) Enable() error {
return nil
}
func (r *PortRemap) Disable() []error {
errs := r.deleteIPTablesRules()
r.enabled = false
return errs
}
func (r *PortRemap) NetfilerDHook(table string) error {
if !r.enabled {
return nil
}
return r.insertIPTablesRules(table)
}
func (nh *NetfilterHelper) PortRemap(name string, from, to uint16, addr []netlink.Addr) *PortRemap {
return &PortRemap{
IPTables: nh.IPTables,