listen refactoring

This commit is contained in:
Vladimir Avtsenov 2024-10-21 00:03:51 +03:00
parent 53a13d5e90
commit 3e42d0deb0
2 changed files with 119 additions and 113 deletions

222
kvas2.go
View File

@ -5,8 +5,8 @@ import (
"errors" "errors"
"fmt" "fmt"
"net" "net"
"os"
"strings" "strings"
"sync"
"time" "time"
"kvas2-go/dns-proxy" "kvas2-go/dns-proxy"
@ -43,88 +43,106 @@ type App struct {
dnsOverrider4 *netfilterHelper.PortRemap dnsOverrider4 *netfilterHelper.PortRemap
} }
func (a *App) Listen(ctx context.Context) []error { func (a *App) handleLink(event netlink.LinkUpdate) {
if a.isRunning { switch event.Change {
return []error{ErrAlreadyRunning} case 0x00000001:
} log.Debug().
a.isRunning = true Str("interface", event.Link.Attrs().Name).
defer func() { a.isRunning = false }() Str("operstatestr", event.Attrs().OperState.String()).
Int("operstate", int(event.Attrs().OperState)).
errs := make([]error, 0) Msg("interface change")
isError := make(chan struct{}) if event.Attrs().OperState != netlink.OperDown {
for _, group := range a.Groups {
var once sync.Once if group.Interface == event.Link.Attrs().Name {
var errsMu sync.Mutex err := group.ifaceToIPSet.IfaceHandle()
handleError := func(err error) { if err != nil {
errsMu.Lock() log.Error().Int("group", group.ID).Err(err).Msg("error while handling interface up")
defer errsMu.Unlock() }
}
errs = append(errs, err)
once.Do(func() { close(isError) })
}
handleErrors := func(errs2 []error) {
errsMu.Lock()
defer errsMu.Unlock()
errs = append(errs, errs2...)
once.Do(func() { close(isError) })
}
defer func() {
if r := recover(); r != nil {
if err, ok := r.(error); ok {
handleError(err)
} else {
handleError(fmt.Errorf("%v", r))
} }
} }
}() case 0xFFFFFFFF:
switch event.Header.Type {
case 16:
log.Debug().
Str("interface", event.Link.Attrs().Name).
Int("type", int(event.Header.Type)).
Msg("interface add")
case 17:
log.Debug().
Str("interface", event.Link.Attrs().Name).
Int("type", int(event.Header.Type)).
Msg("interface del")
}
}
}
func (a *App) listen(ctx context.Context) (err error) {
errChan := make(chan error)
newCtx, cancel := context.WithCancel(ctx) newCtx, cancel := context.WithCancel(ctx)
defer cancel() defer cancel()
go func() {
err := a.DNSProxy.Listen(newCtx)
if err != nil {
errChan <- fmt.Errorf("failed to serve DNS proxy: %v", err)
}
}()
a.dnsOverrider4 = a.NetfilterHelper4.PortRemap(fmt.Sprintf("%sDNSOR", a.Config.ChainPostfix), 53, a.Config.ListenPort) a.dnsOverrider4 = a.NetfilterHelper4.PortRemap(fmt.Sprintf("%sDNSOR", a.Config.ChainPostfix), 53, a.Config.ListenPort)
err := a.dnsOverrider4.Enable() err = a.dnsOverrider4.Enable()
if err != nil {
return fmt.Errorf("failed to override DNS: %v", err)
}
defer func() {
// TODO: Handle error
_ = a.dnsOverrider4.Disable()
}()
for _, group := range a.Groups { for _, group := range a.Groups {
err = group.Enable() err = group.Enable()
if err != nil { if err != nil {
handleError(fmt.Errorf("failed to enable group: %w", err)) return fmt.Errorf("failed to enable group: %w", err)
return errs
} }
} }
defer func() {
go func() { for _, group := range a.Groups {
if err := a.DNSProxy.Listen(newCtx); err != nil { // TODO: Handle error
handleError(fmt.Errorf("failed to initialize DNS proxy: %v", err)) _ = group.Disable()
} }
}() }()
link := make(chan netlink.LinkUpdate) socketPath := "/opt/var/run/kvas2-go.sock"
done := make(chan struct{}) err = os.Remove(socketPath)
netlink.LinkSubscribe(link, done) if err != nil && !errors.Is(err, os.ErrNotExist) {
return fmt.Errorf("failed to remove existed UNIX socket: %w", err)
exitListenerLoop := false
listener, err := net.Listen("unix", "/opt/var/run/kvas2-go.sock")
if err != nil {
handleError(fmt.Errorf("error while serve UNIX socket: %v", err))
return errs
} }
defer listener.Close() socket, err := net.Listen("unix", socketPath)
if err != nil {
return fmt.Errorf("error while serve UNIX socket: %v", err)
}
defer func() {
// TODO: Handle error
_ = socket.Close()
_ = os.Remove(socketPath)
}()
go func() { go func() {
for { for {
if exitListenerLoop { conn, err := socket.Accept()
if err != nil {
if !strings.Contains(err.Error(), "use of closed network connection") {
log.Error().Err(err).Msg("error while listening unix socket")
}
break break
} }
conn, err := listener.Accept()
if err != nil {
log.Error().Err(err).Msg("error while listening unix socket")
}
go func(conn net.Conn) { go func(conn net.Conn) {
defer conn.Close() defer func() {
// TODO: Handle error
_ = conn.Close()
}()
buf := make([]byte, 1024) buf := make([]byte, 1024)
n, err := conn.Read(buf) n, err := conn.Read(buf)
@ -154,65 +172,55 @@ func (a *App) Listen(ctx context.Context) []error {
} }
}() }()
Loop: link := make(chan netlink.LinkUpdate)
done := make(chan struct{})
err = netlink.LinkSubscribe(link, done)
if err != nil {
return fmt.Errorf("failed to subscribe to link updates: %w", err)
}
defer func() {
close(done)
}()
for { for {
select { select {
case event := <-link: case event := <-link:
switch event.Change { a.handleLink(event)
case 0x00000001: case err := <-errChan:
log.Debug(). return err
Str("interface", event.Link.Attrs().Name).
Str("operstatestr", event.Attrs().OperState.String()).
Int("operstate", int(event.Attrs().OperState)).
Msg("interface change")
if event.Attrs().OperState != netlink.OperDown {
for _, group := range a.Groups {
if group.Interface == event.Link.Attrs().Name {
err = group.ifaceToIPSet.IfaceHandle()
if err != nil {
log.Error().Int("group", group.ID).Err(err).Msg("error while handling interface up")
}
}
}
}
case 0xFFFFFFFF:
switch event.Header.Type {
case 16:
log.Debug().
Str("interface", event.Link.Attrs().Name).
Int("type", int(event.Header.Type)).
Msg("interface add")
case 17:
log.Debug().
Str("interface", event.Link.Attrs().Name).
Int("type", int(event.Header.Type)).
Msg("interface del")
}
}
case <-ctx.Done(): case <-ctx.Done():
break Loop return nil
case <-isError:
break Loop
} }
} }
}
exitListenerLoop = true func (a *App) Listen(ctx context.Context) (err error) {
if a.isRunning {
close(done) return ErrAlreadyRunning
errs2 := a.dnsOverrider4.Disable()
if errs2 != nil {
handleErrors(errs2)
} }
a.isRunning = true
defer func() {
a.isRunning = false
}()
for _, group := range a.Groups { defer func() {
errs2 = group.Disable() if r := recover(); r != nil {
if errs2 != nil { var recoveredError error
handleErrors(errs2) var ok bool
if recoveredError, ok = r.(error); !ok {
recoveredError = fmt.Errorf("%v", r)
}
err = fmt.Errorf("recovered error: %w", recoveredError)
} }
}()
appErr := a.listen(ctx)
if appErr != nil {
return appErr
} }
return errs return err
} }
func (a *App) AddGroup(group *models.Group) error { func (a *App) AddGroup(group *models.Group) error {
@ -289,7 +297,7 @@ func (a *App) SyncGroup(group *Group) error {
} }
} }
for addr, _ := range oldIpsetAddresses { for addr := range oldIpsetAddresses {
if _, exists := newIpsetAddressesMap[addr]; exists { if _, exists := newIpsetAddressesMap[addr]; exists {
continue continue
} }

10
main.go
View File

@ -27,11 +27,9 @@ func main() {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
appErrsChan := make(chan []error) appResult := make(chan error)
go func() { go func() {
errs := app.Listen(ctx) appResult <- app.Listen(ctx)
appErrsChan <- errs
}() }()
log.Info().Msg("starting service") log.Info().Msg("starting service")
@ -41,8 +39,8 @@ func main() {
for { for {
select { select {
case appErrs, _ := <-appErrsChan: case err, _ := <-appResult:
for _, err = range appErrs { if err != nil {
log.Error().Err(err).Msg("failed to start application") log.Error().Err(err).Msg("failed to start application")
} }
log.Info().Msg("exiting application") log.Info().Msg("exiting application")