group refactoring
This commit is contained in:
parent
ff6ab7b859
commit
fdb1038ba9
79
group.go
79
group.go
@ -1,79 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"kvas2-go/models"
|
||||
"kvas2-go/netfilter-helper"
|
||||
|
||||
"github.com/coreos/go-iptables/iptables"
|
||||
)
|
||||
|
||||
type Group struct {
|
||||
*models.Group
|
||||
|
||||
Enabled bool
|
||||
|
||||
iptables *iptables.IPTables
|
||||
ipset *netfilterHelper.IPSet
|
||||
ipsetToLink *netfilterHelper.IPSetToLink
|
||||
}
|
||||
|
||||
func (g *Group) AddIP(address net.IP, ttl time.Duration) error {
|
||||
ttlSeconds := uint32(ttl.Seconds())
|
||||
return g.ipset.AddIP(address, &ttlSeconds)
|
||||
}
|
||||
|
||||
func (g *Group) DelIP(address net.IP) error {
|
||||
return g.ipset.DelIP(address)
|
||||
}
|
||||
|
||||
func (g *Group) ListIP() (map[string]*uint32, error) {
|
||||
return g.ipset.ListIPs()
|
||||
}
|
||||
|
||||
func (g *Group) Enable() error {
|
||||
if g.Enabled {
|
||||
return nil
|
||||
}
|
||||
defer func() {
|
||||
if !g.Enabled {
|
||||
_ = g.Disable()
|
||||
}
|
||||
}()
|
||||
|
||||
if g.FixProtect {
|
||||
err := g.iptables.AppendUnique("filter", "_NDM_SL_FORWARD", "-o", g.Interface, "-m", "state", "--state", "NEW", "-j", "_NDM_SL_PROTECT")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to fix protect: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
err := g.ipsetToLink.Enable()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
g.Enabled = true
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (g *Group) Disable() []error {
|
||||
var errs []error
|
||||
|
||||
if !g.Enabled {
|
||||
return nil
|
||||
}
|
||||
|
||||
err := g.ipsetToLink.Disable()
|
||||
if err != nil {
|
||||
errs = append(errs, err...)
|
||||
}
|
||||
|
||||
g.Enabled = false
|
||||
|
||||
return errs
|
||||
}
|
184
group/group.go
Normal file
184
group/group.go
Normal file
@ -0,0 +1,184 @@
|
||||
package group
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"kvas2-go/models"
|
||||
"kvas2-go/netfilter-helper"
|
||||
"kvas2-go/records"
|
||||
|
||||
"github.com/coreos/go-iptables/iptables"
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/vishvananda/netlink"
|
||||
)
|
||||
|
||||
type Group struct {
|
||||
*models.Group
|
||||
|
||||
enabled bool
|
||||
iptables *iptables.IPTables
|
||||
ipset *netfilterHelper.IPSet
|
||||
ipsetToLink *netfilterHelper.IPSetToLink
|
||||
}
|
||||
|
||||
func (g *Group) AddIP(address net.IP, ttl time.Duration) error {
|
||||
ttlSeconds := uint32(ttl.Seconds())
|
||||
return g.ipset.AddIP(address, &ttlSeconds)
|
||||
}
|
||||
|
||||
func (g *Group) DelIP(address net.IP) error {
|
||||
return g.ipset.DelIP(address)
|
||||
}
|
||||
|
||||
func (g *Group) ListIP() (map[string]*uint32, error) {
|
||||
return g.ipset.ListIPs()
|
||||
}
|
||||
|
||||
func (g *Group) Enable() error {
|
||||
if g.enabled {
|
||||
return nil
|
||||
}
|
||||
defer func() {
|
||||
if !g.enabled {
|
||||
_ = g.Disable()
|
||||
}
|
||||
}()
|
||||
|
||||
if g.FixProtect {
|
||||
err := g.iptables.AppendUnique("filter", "_NDM_SL_FORWARD", "-o", g.Interface, "-m", "state", "--state", "NEW", "-j", "_NDM_SL_PROTECT")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to fix protect: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
err := g.ipsetToLink.Enable()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
g.enabled = true
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (g *Group) Disable() []error {
|
||||
var errs []error
|
||||
|
||||
if !g.enabled {
|
||||
return nil
|
||||
}
|
||||
|
||||
if g.FixProtect {
|
||||
err := g.iptables.Delete("filter", "_NDM_SL_FORWARD", "-o", g.Interface, "-m", "state", "--state", "NEW", "-j", "_NDM_SL_PROTECT")
|
||||
if err != nil {
|
||||
errs = append(errs, fmt.Errorf("failed to remove fix protect: %w", err))
|
||||
}
|
||||
}
|
||||
|
||||
err := g.ipsetToLink.Disable()
|
||||
if err != nil {
|
||||
errs = append(errs, err...)
|
||||
}
|
||||
|
||||
g.enabled = false
|
||||
|
||||
return errs
|
||||
}
|
||||
|
||||
func (g *Group) Sync(records *records.Records) error {
|
||||
now := time.Now()
|
||||
|
||||
addresses := make(map[string]time.Duration)
|
||||
knownDomains := records.ListKnownDomains()
|
||||
for _, domain := range g.Rules {
|
||||
if !domain.IsEnabled() {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, domainName := range knownDomains {
|
||||
if !domain.IsMatch(domainName) {
|
||||
continue
|
||||
}
|
||||
|
||||
domainAddresses := records.GetARecords(domainName)
|
||||
for _, address := range domainAddresses {
|
||||
ttl := now.Sub(address.Deadline)
|
||||
if oldTTL, ok := addresses[string(address.Address)]; !ok || ttl > oldTTL {
|
||||
addresses[string(address.Address)] = ttl
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
currentAddresses, err := g.ListIP()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get old ipset list: %w", err)
|
||||
}
|
||||
|
||||
for addr, ttl := range addresses {
|
||||
// TODO: Check TTL
|
||||
if _, exists := currentAddresses[addr]; exists {
|
||||
continue
|
||||
}
|
||||
ip := net.IP(addr)
|
||||
err = g.AddIP(ip, ttl)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Str("address", ip.String()).
|
||||
Err(err).
|
||||
Msg("failed to add address")
|
||||
} else {
|
||||
log.Trace().
|
||||
Str("address", ip.String()).
|
||||
Err(err).
|
||||
Msg("add address")
|
||||
}
|
||||
}
|
||||
|
||||
for addr := range currentAddresses {
|
||||
if _, ok := addresses[addr]; ok {
|
||||
continue
|
||||
}
|
||||
ip := net.IP(addr)
|
||||
err = g.DelIP(ip)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Str("address", ip.String()).
|
||||
Err(err).
|
||||
Msg("failed to delete address")
|
||||
} else {
|
||||
log.Trace().
|
||||
Str("address", ip.String()).
|
||||
Err(err).
|
||||
Msg("del address")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (g *Group) NetfilterDHook(table string) error {
|
||||
return g.ipsetToLink.NetfilterDHook(table)
|
||||
}
|
||||
|
||||
func (g *Group) LinkUpdateHook(event netlink.LinkUpdate) error {
|
||||
return g.ipsetToLink.LinkUpdateHook(event)
|
||||
}
|
||||
|
||||
func NewGroup(group *models.Group, nh4 *netfilterHelper.NetfilterHelper, chainPrefix, ipsetNamePrefix string) (*Group, error) {
|
||||
ipsetName := fmt.Sprintf("%s%8x", ipsetNamePrefix, group.ID.ID())
|
||||
ipset, err := nh4.IPSet(ipsetName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to initialize ipset: %w", err)
|
||||
}
|
||||
|
||||
ipsetToLink := nh4.IPSetToLink(fmt.Sprintf("%s%8x", chainPrefix, group.ID.ID()), group.Interface, ipsetName)
|
||||
return &Group{
|
||||
Group: group,
|
||||
iptables: nh4.IPTables,
|
||||
ipset: ipset,
|
||||
ipsetToLink: ipsetToLink,
|
||||
}, nil
|
||||
}
|
105
kvas2.go
105
kvas2.go
@ -11,6 +11,7 @@ import (
|
||||
"time"
|
||||
|
||||
"kvas2-go/dns-mitm-proxy"
|
||||
"kvas2-go/group"
|
||||
"kvas2-go/models"
|
||||
"kvas2-go/netfilter-helper"
|
||||
"kvas2-go/records"
|
||||
@ -43,7 +44,7 @@ type App struct {
|
||||
NetfilterHelper4 *netfilterHelper.NetfilterHelper
|
||||
NetfilterHelper6 *netfilterHelper.NetfilterHelper
|
||||
Records *records.Records
|
||||
Groups map[uuid.UUID]*Group
|
||||
Groups map[uuid.UUID]*group.Group
|
||||
|
||||
Link netlink.Link
|
||||
|
||||
@ -68,7 +69,7 @@ func (a *App) handleLink(event netlink.LinkUpdate) {
|
||||
continue
|
||||
}
|
||||
|
||||
err := group.ipsetToLink.LinkUpdateHook()
|
||||
err := group.LinkUpdateHook(event)
|
||||
if err != nil {
|
||||
log.Error().Str("group", group.ID.String()).Err(err).Msg("error while handling interface up")
|
||||
}
|
||||
@ -204,16 +205,16 @@ func (a *App) start(ctx context.Context) (err error) {
|
||||
args := strings.Split(string(buf[:n]), ":")
|
||||
if len(args) == 3 && args[0] == "netfilter.d" {
|
||||
log.Debug().Str("table", args[2]).Msg("netfilter.d event")
|
||||
err = a.dnsOverrider4.NetfilerDHook(args[2])
|
||||
err = a.dnsOverrider4.NetfilterDHook(args[2])
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("error while fixing iptables after netfilter.d")
|
||||
}
|
||||
err = a.dnsOverrider6.NetfilerDHook(args[2])
|
||||
err = a.dnsOverrider6.NetfilterDHook(args[2])
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("error while fixing iptables after netfilter.d")
|
||||
}
|
||||
for _, group := range a.Groups {
|
||||
err := group.ipsetToLink.NetfilerDHook(args[2])
|
||||
err := group.NetfilterDHook(args[2])
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("error while fixing iptables after netfilter.d")
|
||||
}
|
||||
@ -276,97 +277,17 @@ func (a *App) Start(ctx context.Context) (err error) {
|
||||
return err
|
||||
}
|
||||
|
||||
func (a *App) AddGroup(group *models.Group) error {
|
||||
if _, exists := a.Groups[group.ID]; exists {
|
||||
func (a *App) AddGroup(groupModel *models.Group) error {
|
||||
if _, exists := a.Groups[groupModel.ID]; exists {
|
||||
return ErrGroupIDConflict
|
||||
}
|
||||
|
||||
ipsetName := fmt.Sprintf("%s%8x", a.Config.IpSetPrefix, group.ID.ID())
|
||||
ipset, err := a.NetfilterHelper4.IPSet(ipsetName)
|
||||
grp, err := group.NewGroup(groupModel, a.NetfilterHelper4, a.Config.ChainPrefix, a.Config.IpSetPrefix)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to initialize ipset: %w", err)
|
||||
}
|
||||
|
||||
grp := &Group{
|
||||
Group: group,
|
||||
iptables: a.NetfilterHelper4.IPTables,
|
||||
ipset: ipset,
|
||||
ipsetToLink: a.NetfilterHelper4.IPSetToLink(fmt.Sprintf("%sR_%8x", a.Config.ChainPrefix, group.ID.ID()), group.Interface, ipsetName, false),
|
||||
return fmt.Errorf("failed to create group: %w", err)
|
||||
}
|
||||
a.Groups[grp.ID] = grp
|
||||
return a.SyncGroup(grp)
|
||||
}
|
||||
|
||||
func (a *App) SyncGroup(group *Group) error {
|
||||
now := time.Now()
|
||||
|
||||
addresses := make(map[string]time.Duration)
|
||||
knownDomains := a.Records.ListKnownDomains()
|
||||
for _, domain := range group.Rules {
|
||||
if !domain.IsEnabled() {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, domainName := range knownDomains {
|
||||
if !domain.IsMatch(domainName) {
|
||||
continue
|
||||
}
|
||||
|
||||
domainAddresses := a.Records.GetARecords(domainName)
|
||||
for _, address := range domainAddresses {
|
||||
ttl := now.Sub(address.Deadline)
|
||||
if oldTTL, ok := addresses[string(address.Address)]; !ok || ttl > oldTTL {
|
||||
addresses[string(address.Address)] = ttl
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
currentAddresses, err := group.ListIP()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get old ipset list: %w", err)
|
||||
}
|
||||
|
||||
for addr, ttl := range addresses {
|
||||
// TODO: Check TTL
|
||||
if _, exists := currentAddresses[addr]; exists {
|
||||
continue
|
||||
}
|
||||
ip := net.IP(addr)
|
||||
err = group.AddIP(ip, ttl)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Str("address", ip.String()).
|
||||
Err(err).
|
||||
Msg("failed to add address")
|
||||
} else {
|
||||
log.Trace().
|
||||
Str("address", ip.String()).
|
||||
Err(err).
|
||||
Msg("add address")
|
||||
}
|
||||
}
|
||||
|
||||
for addr := range currentAddresses {
|
||||
if _, ok := addresses[addr]; ok {
|
||||
continue
|
||||
}
|
||||
ip := net.IP(addr)
|
||||
err = group.DelIP(ip)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Str("address", ip.String()).
|
||||
Err(err).
|
||||
Msg("failed to delete address")
|
||||
} else {
|
||||
log.Trace().
|
||||
Str("address", ip.String()).
|
||||
Err(err).
|
||||
Msg("del address")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
return grp.Sync(a.Records)
|
||||
}
|
||||
|
||||
func (a *App) ListInterfaces() ([]net.Interface, error) {
|
||||
@ -544,7 +465,7 @@ func New(config Config) (*App, error) {
|
||||
}
|
||||
|
||||
app.Records = records.New()
|
||||
app.Groups = make(map[uuid.UUID]*Group, 0)
|
||||
app.Groups = make(map[uuid.UUID]*group.Group)
|
||||
|
||||
link, err := netlink.LinkByName(app.Config.LinkName)
|
||||
if err != nil {
|
||||
@ -572,7 +493,7 @@ func New(config Config) (*App, error) {
|
||||
return nil, fmt.Errorf("failed to clear iptables: %w", err)
|
||||
}
|
||||
|
||||
app.Groups = make(map[uuid.UUID]*Group)
|
||||
app.Groups = make(map[uuid.UUID]*group.Group)
|
||||
|
||||
return app, nil
|
||||
}
|
||||
|
@ -265,21 +265,21 @@ func (r *IPSetToLink) Disable() []error {
|
||||
return errs
|
||||
}
|
||||
|
||||
func (r *IPSetToLink) NetfilerDHook(table string) error {
|
||||
func (r *IPSetToLink) NetfilterDHook(table string) error {
|
||||
if !r.enabled {
|
||||
return nil
|
||||
}
|
||||
return r.insertIPTablesRules(table)
|
||||
}
|
||||
|
||||
func (r *IPSetToLink) LinkUpdateHook() error {
|
||||
if !r.enabled {
|
||||
func (r *IPSetToLink) LinkUpdateHook(event netlink.LinkUpdate) error {
|
||||
if !r.enabled || event.Change != 1 || event.Link.Attrs().Name != r.IfaceName || event.Attrs().OperState != netlink.OperUp {
|
||||
return nil
|
||||
}
|
||||
return r.insertIPRoute()
|
||||
}
|
||||
|
||||
func (nh *NetfilterHelper) IPSetToLink(name string, ifaceName, ipsetName string, softwareMode bool) *IPSetToLink {
|
||||
func (nh *NetfilterHelper) IPSetToLink(name string, ifaceName, ipsetName string) *IPSetToLink {
|
||||
return &IPSetToLink{
|
||||
IPTables: nh.IPTables,
|
||||
ChainName: name,
|
||||
|
@ -105,7 +105,7 @@ func (r *PortRemap) Disable() []error {
|
||||
return errs
|
||||
}
|
||||
|
||||
func (r *PortRemap) NetfilerDHook(table string) error {
|
||||
func (r *PortRemap) NetfilterDHook(table string) error {
|
||||
if !r.enabled {
|
||||
return nil
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user