diff --git a/connection/connection.go b/connection/connection.go index dcd202c..a6f9d1b 100644 --- a/connection/connection.go +++ b/connection/connection.go @@ -28,6 +28,8 @@ var DefaultConfig = Config{ type ( Connection struct { + mu sync.Mutex + base context.Context cancel context.CancelFunc conn atomic.Pointer[amqp091.Connection] config *Config @@ -62,26 +64,33 @@ func New(ctx context.Context, config Config, events Events) (*Connection, error) return nil, ErrOnConnectionReady } - ctx, cancel := context.WithCancel(ctx) - c := &Connection{ + base: ctx, config: &config, - cancel: cancel, onBeforeConnectionReady: events.OnBeforeConnectionReady, onConnectionReady: events.OnConnectionReady, onError: events.OnError, } c.once = sync.OnceFunc(func() { + c.mu.Lock() c.cancel() + c.mu.Unlock() c.connectionDispose() }) - return c.reconnect(ctx) + return c.reconnect() } -func (c *Connection) reconnect(ctx context.Context) (*Connection, error) { +func (c *Connection) reconnect() (*Connection, error) { connect := c.connect() + var ctx context.Context + c.mu.Lock() + if c.cancel != nil { + c.cancel() + } + ctx, c.cancel = context.WithCancel(c.base) + c.mu.Unlock() if err := connect(ctx); err == nil { return c, nil @@ -134,7 +143,7 @@ func (c *Connection) handleReconnect(ctx context.Context, connection *amqp091.Co c.connectionDispose() - if _, err := c.reconnect(ctx); err != nil { + if _, err := c.reconnect(); err != nil { return } } diff --git a/consumer/queue.go b/consumer/queue.go index 1b396c5..336703a 100644 --- a/consumer/queue.go +++ b/consumer/queue.go @@ -15,13 +15,11 @@ type ( ) func (c *Consumer[T]) Start(base context.Context) error { - _, cancel := context.WithCancel(base) - conn, err := connection.New(base, c.cfg.connectionOptions, connection.Events{ OnBeforeConnectionReady: func(ctx context.Context) error { - cancel() - _, cancel = context.WithCancel(ctx) - return c.watcher.Acquire(base, int64(c.cfg.queueConfig.Workers)) + defer c.watcher.Release(int64(c.cfg.queueConfig.Workers)) + // Here we wait for the workers to be released. + return c.watcher.Acquire(ctx, int64(c.cfg.queueConfig.Workers)) }, OnConnectionReady: func(ctx context.Context, connection *amqp091.Connection) error { diff --git a/publisher/publisher.go b/publisher/publisher.go index 6cadd7c..4f012cb 100644 --- a/publisher/publisher.go +++ b/publisher/publisher.go @@ -126,6 +126,8 @@ func (p *Publisher[T]) onConnectionReady(cfg Config[T]) connection.OnConnectionR return err } + defer p.semaphore.Release(1) + notifyClose, err := p.swapChannel(connection, cfg) if err != nil { return err