From 3e42d0deb0e209f9461b37983a120bbe950245ea Mon Sep 17 00:00:00 2001 From: Vladimir Avtsenov Date: Mon, 21 Oct 2024 00:03:51 +0300 Subject: [PATCH] listen refactoring --- kvas2.go | 222 ++++++++++++++++++++++++++++--------------------------- main.go | 10 +-- 2 files changed, 119 insertions(+), 113 deletions(-) diff --git a/kvas2.go b/kvas2.go index af12531..aff8866 100644 --- a/kvas2.go +++ b/kvas2.go @@ -5,8 +5,8 @@ import ( "errors" "fmt" "net" + "os" "strings" - "sync" "time" "kvas2-go/dns-proxy" @@ -43,88 +43,106 @@ type App struct { dnsOverrider4 *netfilterHelper.PortRemap } -func (a *App) Listen(ctx context.Context) []error { - if a.isRunning { - return []error{ErrAlreadyRunning} - } - a.isRunning = true - defer func() { a.isRunning = false }() - - errs := make([]error, 0) - isError := make(chan struct{}) - - var once sync.Once - var errsMu sync.Mutex - handleError := func(err error) { - errsMu.Lock() - defer errsMu.Unlock() - - errs = append(errs, err) - once.Do(func() { close(isError) }) - } - handleErrors := func(errs2 []error) { - errsMu.Lock() - defer errsMu.Unlock() - - errs = append(errs, errs2...) - once.Do(func() { close(isError) }) - } - - defer func() { - if r := recover(); r != nil { - if err, ok := r.(error); ok { - handleError(err) - } else { - handleError(fmt.Errorf("%v", r)) +func (a *App) handleLink(event netlink.LinkUpdate) { + switch event.Change { + case 0x00000001: + log.Debug(). + Str("interface", event.Link.Attrs().Name). + Str("operstatestr", event.Attrs().OperState.String()). + Int("operstate", int(event.Attrs().OperState)). + Msg("interface change") + if event.Attrs().OperState != netlink.OperDown { + for _, group := range a.Groups { + if group.Interface == event.Link.Attrs().Name { + err := group.ifaceToIPSet.IfaceHandle() + if err != nil { + log.Error().Int("group", group.ID).Err(err).Msg("error while handling interface up") + } + } } } - }() + case 0xFFFFFFFF: + switch event.Header.Type { + case 16: + log.Debug(). + Str("interface", event.Link.Attrs().Name). + Int("type", int(event.Header.Type)). + Msg("interface add") + case 17: + log.Debug(). + Str("interface", event.Link.Attrs().Name). + Int("type", int(event.Header.Type)). + Msg("interface del") + } + } +} + +func (a *App) listen(ctx context.Context) (err error) { + errChan := make(chan error) newCtx, cancel := context.WithCancel(ctx) defer cancel() + go func() { + err := a.DNSProxy.Listen(newCtx) + if err != nil { + errChan <- fmt.Errorf("failed to serve DNS proxy: %v", err) + } + }() + a.dnsOverrider4 = a.NetfilterHelper4.PortRemap(fmt.Sprintf("%sDNSOR", a.Config.ChainPostfix), 53, a.Config.ListenPort) - err := a.dnsOverrider4.Enable() + err = a.dnsOverrider4.Enable() + if err != nil { + return fmt.Errorf("failed to override DNS: %v", err) + } + defer func() { + // TODO: Handle error + _ = a.dnsOverrider4.Disable() + }() for _, group := range a.Groups { err = group.Enable() if err != nil { - handleError(fmt.Errorf("failed to enable group: %w", err)) - return errs + return fmt.Errorf("failed to enable group: %w", err) } } - - go func() { - if err := a.DNSProxy.Listen(newCtx); err != nil { - handleError(fmt.Errorf("failed to initialize DNS proxy: %v", err)) + defer func() { + for _, group := range a.Groups { + // TODO: Handle error + _ = group.Disable() } }() - link := make(chan netlink.LinkUpdate) - done := make(chan struct{}) - netlink.LinkSubscribe(link, done) - - exitListenerLoop := false - listener, err := net.Listen("unix", "/opt/var/run/kvas2-go.sock") - if err != nil { - handleError(fmt.Errorf("error while serve UNIX socket: %v", err)) - return errs + socketPath := "/opt/var/run/kvas2-go.sock" + err = os.Remove(socketPath) + if err != nil && !errors.Is(err, os.ErrNotExist) { + return fmt.Errorf("failed to remove existed UNIX socket: %w", err) } - defer listener.Close() + socket, err := net.Listen("unix", socketPath) + if err != nil { + return fmt.Errorf("error while serve UNIX socket: %v", err) + } + defer func() { + // TODO: Handle error + _ = socket.Close() + _ = os.Remove(socketPath) + }() go func() { for { - if exitListenerLoop { + conn, err := socket.Accept() + if err != nil { + if !strings.Contains(err.Error(), "use of closed network connection") { + log.Error().Err(err).Msg("error while listening unix socket") + } break } - conn, err := listener.Accept() - if err != nil { - log.Error().Err(err).Msg("error while listening unix socket") - } - go func(conn net.Conn) { - defer conn.Close() + defer func() { + // TODO: Handle error + _ = conn.Close() + }() buf := make([]byte, 1024) n, err := conn.Read(buf) @@ -154,65 +172,55 @@ func (a *App) Listen(ctx context.Context) []error { } }() -Loop: + link := make(chan netlink.LinkUpdate) + done := make(chan struct{}) + err = netlink.LinkSubscribe(link, done) + if err != nil { + return fmt.Errorf("failed to subscribe to link updates: %w", err) + } + defer func() { + close(done) + }() + for { select { case event := <-link: - switch event.Change { - case 0x00000001: - log.Debug(). - Str("interface", event.Link.Attrs().Name). - Str("operstatestr", event.Attrs().OperState.String()). - Int("operstate", int(event.Attrs().OperState)). - Msg("interface change") - if event.Attrs().OperState != netlink.OperDown { - for _, group := range a.Groups { - if group.Interface == event.Link.Attrs().Name { - err = group.ifaceToIPSet.IfaceHandle() - if err != nil { - log.Error().Int("group", group.ID).Err(err).Msg("error while handling interface up") - } - } - } - } - case 0xFFFFFFFF: - switch event.Header.Type { - case 16: - log.Debug(). - Str("interface", event.Link.Attrs().Name). - Int("type", int(event.Header.Type)). - Msg("interface add") - case 17: - log.Debug(). - Str("interface", event.Link.Attrs().Name). - Int("type", int(event.Header.Type)). - Msg("interface del") - } - } + a.handleLink(event) + case err := <-errChan: + return err case <-ctx.Done(): - break Loop - case <-isError: - break Loop + return nil } } +} - exitListenerLoop = true - - close(done) - - errs2 := a.dnsOverrider4.Disable() - if errs2 != nil { - handleErrors(errs2) +func (a *App) Listen(ctx context.Context) (err error) { + if a.isRunning { + return ErrAlreadyRunning } + a.isRunning = true + defer func() { + a.isRunning = false + }() - for _, group := range a.Groups { - errs2 = group.Disable() - if errs2 != nil { - handleErrors(errs2) + defer func() { + if r := recover(); r != nil { + var recoveredError error + var ok bool + if recoveredError, ok = r.(error); !ok { + recoveredError = fmt.Errorf("%v", r) + } + + err = fmt.Errorf("recovered error: %w", recoveredError) } + }() + + appErr := a.listen(ctx) + if appErr != nil { + return appErr } - return errs + return err } func (a *App) AddGroup(group *models.Group) error { @@ -289,7 +297,7 @@ func (a *App) SyncGroup(group *Group) error { } } - for addr, _ := range oldIpsetAddresses { + for addr := range oldIpsetAddresses { if _, exists := newIpsetAddressesMap[addr]; exists { continue } diff --git a/main.go b/main.go index 0eec100..22e3c73 100644 --- a/main.go +++ b/main.go @@ -27,11 +27,9 @@ func main() { ctx, cancel := context.WithCancel(context.Background()) - appErrsChan := make(chan []error) + appResult := make(chan error) go func() { - errs := app.Listen(ctx) - appErrsChan <- errs - + appResult <- app.Listen(ctx) }() log.Info().Msg("starting service") @@ -41,8 +39,8 @@ func main() { for { select { - case appErrs, _ := <-appErrsChan: - for _, err = range appErrs { + case err, _ := <-appResult: + if err != nil { log.Error().Err(err).Msg("failed to start application") } log.Info().Msg("exiting application")