diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index f336ed7..7561fb3 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -10,41 +10,41 @@ jobs: name: Run tests strategy: matrix: - go-version: [1.23.x, 1.22.x, 1.21.x, 1.20.x] + go-version: [1.23.x, 1.22.x, 1.21.x] os: [ubuntu-latest, macos-latest, windows-latest] runs-on: ${{ matrix.os }} steps: - - name: Checkout code - uses: actions/checkout@v4 + - name: Checkout code + uses: actions/checkout@v4 - - name: Install Go - uses: actions/setup-go@v5 - with: - go-version: ${{ matrix.go-version }} - cache: false + - name: Install Go + uses: actions/setup-go@v5 + with: + go-version: ${{ matrix.go-version }} + cache: false - - name: Test - run: make test-ci + - name: Test + run: make test-ci codecov: name: Coverage report runs-on: ubuntu-latest steps: - - name: Checkout code - uses: actions/checkout@v4 + - name: Checkout code + uses: actions/checkout@v4 - - name: Install Go - uses: actions/setup-go@v5 - with: - go-version-file: go.mod - cache: false + - name: Install Go + uses: actions/setup-go@v5 + with: + go-version-file: go.mod + cache: false - - name: Test - run: make coverage + - name: Test + run: make coverage - - uses: codecov/codecov-action@v5 - with: - files: coverage.out - fail_ci_if_error: true - verbose: true - env: - CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} \ No newline at end of file + - uses: codecov/codecov-action@v5 + with: + files: coverage.out + fail_ci_if_error: true + verbose: true + env: + CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} diff --git a/README.md b/README.md index e007f12..1a8161e 100644 --- a/README.md +++ b/README.md @@ -35,11 +35,13 @@ Some common use cases include: - Complete pool metrics such as number of running workers, tasks waiting in the queue [and more](#metrics--monitoring) - Configurable parent context to stop all workers when it is cancelled - **New features in v2**: - - Unbounded task queues + - Bounded or Unbounded task queues - Submission of tasks that return results - Awaitable task completion - Type safe APIs for tasks that return errors or results - Panics recovery (panics are captured and returned as errors) + - Subpools with a fraction of the parent pool's maximum number of workers + - Blocking and non-blocking submission of tasks when the queue is full - [API reference](https://pkg.go.dev/github.com/alitto/pond/v2) ## Installation @@ -386,6 +388,24 @@ if err != nil { } ``` +### Bounded task queues (v2) + +By default, task queues are unbounded, meaning that tasks are queued indefinitely until the pool is stopped (or the process runs out of memory). You can limit the number of tasks that can be queued by setting a queue size when creating a pool (`WithQueueSize` option). + +``` go +// Create a pool with a maximum of 10 tasks in the queue +pool := pond.NewPool(1, pond.WithQueueSize(10)) +``` + +**Blocking vs non-blocking task submission** + +When a pool defines a queue size (bounded), you can also specify how to handle tasks submitted when the queue is full. By default, task submission blocks until there is space in the queue (blocking mode), but you can change this behavior to non-blocking by setting the `WithNonBlocking` option to `true` when creating a pool. If the queue is full and non-blocking task submission is enabled, the task is dropped and an error is returned (`ErrQueueFull`). + +``` go +// Create a pool with a maximum of 10 tasks in the queue and non-blocking task submission +pool := pond.NewPool(1, pond.WithQueueSize(10), pond.WithNonBlocking(true)) +``` + ### Metrics & monitoring Each worker pool instance exposes useful metrics that can be queried through the following methods: diff --git a/go.mod b/go.mod index 37f1ff8..95a419c 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,3 @@ module github.com/alitto/pond/v2 -go 1.20 +go 1.21 diff --git a/group_test.go b/group_test.go index 8e900d2..6758f84 100644 --- a/group_test.go +++ b/group_test.go @@ -192,19 +192,23 @@ func TestTaskGroupWithContextCanceled(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) - go func() { - time.Sleep(10 * time.Millisecond) - cancel() - }() + taskStarted := make(chan struct{}) + + task := group.SubmitErr(func() error { + taskStarted <- struct{}{} - err := group.SubmitErr(func() error { select { case <-ctx.Done(): return ctx.Err() case <-time.After(100 * time.Millisecond): return nil } - }).Wait() + }) + + <-taskStarted + cancel() + + err := task.Wait() assert.Equal(t, context.Canceled, err) } diff --git a/internal/linkedbuffer/linkedbuffer.go b/internal/linkedbuffer/linkedbuffer.go index 41651c6..739318f 100644 --- a/internal/linkedbuffer/linkedbuffer.go +++ b/internal/linkedbuffer/linkedbuffer.go @@ -1,6 +1,7 @@ package linkedbuffer import ( + "math" "sync" "sync/atomic" ) @@ -125,7 +126,8 @@ func (b *LinkedBuffer[T]) Len() uint64 { readCount := b.readCount.Load() if writeCount < readCount { - return 0 // Make sure we don't return a negative value + // The writeCount counter wrapped around + return math.MaxUint64 - readCount + writeCount } return writeCount - readCount diff --git a/internal/linkedbuffer/linkedbuffer_test.go b/internal/linkedbuffer/linkedbuffer_test.go index 0a7cbc1..db46f4b 100644 --- a/internal/linkedbuffer/linkedbuffer_test.go +++ b/internal/linkedbuffer/linkedbuffer_test.go @@ -1,6 +1,7 @@ package linkedbuffer import ( + "math" "sync" "sync/atomic" "testing" @@ -88,8 +89,10 @@ func TestLinkedBufferLen(t *testing.T) { assert.Equal(t, uint64(0), buf.Len()) - buf.readCount.Add(1) - assert.Equal(t, uint64(0), buf.Len()) + // Test wrap around + buf.writeCount.Add(math.MaxUint64) + buf.readCount.Add(math.MaxUint64 - 3) + assert.Equal(t, uint64(3), buf.Len()) } func TestLinkedBufferWithReusedBuffer(t *testing.T) { diff --git a/internal/semaphore/semaphore.go b/internal/semaphore/semaphore.go new file mode 100644 index 0000000..78c735c --- /dev/null +++ b/internal/semaphore/semaphore.go @@ -0,0 +1,142 @@ +package semaphore + +import ( + "context" + "fmt" + "sync" +) + +type Weighted struct { + ctx context.Context + cond *sync.Cond + size int + n int + waiting int +} + +func NewWeighted(ctx context.Context, size int) *Weighted { + sem := &Weighted{ + ctx: ctx, + cond: sync.NewCond(&sync.Mutex{}), + size: size, + n: size, + } + + // Notify all waiters when the context is done + context.AfterFunc(ctx, func() { + sem.cond.Broadcast() + }) + + return sem +} + +func (w *Weighted) Acquire(weight int) error { + if weight <= 0 { + return fmt.Errorf("semaphore: weight %d cannot be negative or zero", weight) + } + if weight > w.size { + return fmt.Errorf("semaphore: weight %d is greater than semaphore size %d", weight, w.size) + } + + w.cond.L.Lock() + defer w.cond.L.Unlock() + + done := w.ctx.Done() + + select { + case <-done: + return w.ctx.Err() + default: + } + + for weight > w.n { + // Check if the context is done + select { + case <-done: + return w.ctx.Err() + default: + } + + w.waiting++ + w.cond.Wait() + w.waiting-- + } + + w.n -= weight + + return nil +} + +func (w *Weighted) TryAcquire(weight int) bool { + if weight <= 0 { + return false + } + if weight > w.size { + return false + } + + w.cond.L.Lock() + defer w.cond.L.Unlock() + + // Check if the context is done + select { + case <-w.ctx.Done(): + return false + default: + } + + if weight > w.n { + // Not enough room in the semaphore + return false + } + + w.n -= weight + + return true +} + +func (w *Weighted) Release(weight int) error { + if weight <= 0 { + return fmt.Errorf("semaphore: weight %d cannot be negative or zero", weight) + } + if weight > w.size { + return fmt.Errorf("semaphore: weight %d is greater than semaphore size %d", weight, w.size) + } + + w.cond.L.Lock() + defer w.cond.L.Unlock() + + if weight > w.size-w.n { + return fmt.Errorf("semaphore: trying to release more than acquired: %d > %d", weight, w.size-w.n) + } + + w.n += weight + w.cond.Broadcast() + + return nil +} + +func (w *Weighted) Size() int { + return w.size +} + +func (w *Weighted) Acquired() int { + w.cond.L.Lock() + defer w.cond.L.Unlock() + + return w.size - w.n +} + +func (w *Weighted) Available() int { + w.cond.L.Lock() + defer w.cond.L.Unlock() + + return w.n +} + +func (w *Weighted) Waiting() int { + w.cond.L.Lock() + defer w.cond.L.Unlock() + + return w.waiting +} diff --git a/internal/semaphore/semaphore_test.go b/internal/semaphore/semaphore_test.go new file mode 100644 index 0000000..6800954 --- /dev/null +++ b/internal/semaphore/semaphore_test.go @@ -0,0 +1,192 @@ +package semaphore + +import ( + "context" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/alitto/pond/v2/internal/assert" +) + +func TestWeighted(t *testing.T) { + sem := NewWeighted(context.Background(), 10) + + // Acquire 5 + err := sem.Acquire(5) + assert.Equal(t, nil, err) + + // Acquire 4 + err = sem.Acquire(4) + assert.Equal(t, nil, err) + + // Try to acquire 2 + assert.Equal(t, false, sem.TryAcquire(2)) + + // Try to acquire 1 + assert.Equal(t, true, sem.TryAcquire(1)) + + // Release 7 + sem.Release(7) + + // Try to acquire 7 + assert.Equal(t, true, sem.TryAcquire(7)) +} + +func TestWeightedWithMoreAcquirersThanReleasers(t *testing.T) { + sem := NewWeighted(context.Background(), 6) + + goroutines := 12 + acquire := 2 + release := 5 + wg := sync.WaitGroup{} + acquireSuccessCount := atomic.Uint64{} + acquireFailCount := atomic.Uint64{} + + wg.Add(goroutines) + + // Launch goroutines that try to acquire the semaphore + for i := 0; i < goroutines; i++ { + go func() { + defer wg.Done() + + if err := sem.Acquire(acquire); err != nil { + acquireFailCount.Add(1) + } else { + acquireSuccessCount.Add(1) + } + + if sem.Acquired() >= release { + sem.Release(release) + } + }() + } + + // Wait for goroutines to finish + wg.Wait() + + assert.Equal(t, uint64(12), acquireSuccessCount.Load()) + assert.Equal(t, uint64(0), acquireFailCount.Load()) + assert.Equal(t, 4, sem.Acquired()) +} + +func TestWeightedAcquireWithInvalidWeights(t *testing.T) { + sem := NewWeighted(context.Background(), 10) + + // Acquire 0 + err := sem.Acquire(0) + assert.Equal(t, "semaphore: weight 0 cannot be negative or zero", err.Error()) + + // Try to acquire 0 + res := sem.TryAcquire(0) + assert.Equal(t, false, res) + + // Acquire -1 + err = sem.Acquire(-1) + assert.Equal(t, "semaphore: weight -1 cannot be negative or zero", err.Error()) + + // Try to acquire -1 + res = sem.TryAcquire(-1) + assert.Equal(t, false, res) + + // Acquire 11 + err = sem.Acquire(11) + assert.Equal(t, "semaphore: weight 11 is greater than semaphore size 10", err.Error()) + + // Try to acquire 11 + res = sem.TryAcquire(11) + assert.Equal(t, false, res) +} + +func TestWeightedReleaseWithInvalidWeights(t *testing.T) { + sem := NewWeighted(context.Background(), 10) + + // Release 0 + err := sem.Release(0) + assert.Equal(t, "semaphore: weight 0 cannot be negative or zero", err.Error()) + + // Release -1 + err = sem.Release(-1) + assert.Equal(t, "semaphore: weight -1 cannot be negative or zero", err.Error()) + + // Release 11 + err = sem.Release(11) + assert.Equal(t, "semaphore: weight 11 is greater than semaphore size 10", err.Error()) + + // Release 1 + err = sem.Release(1) + assert.Equal(t, "semaphore: trying to release more than acquired: 1 > 0", err.Error()) +} + +func TestWeightedWithContextCanceled(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + + sem := NewWeighted(ctx, 10) + + // Acquire the semaphore + err := sem.Acquire(5) + assert.Equal(t, nil, err) + + // Cancel the context + cancel() + + // Attempt to acquire the semaphore + err = sem.Acquire(5) + assert.Equal(t, context.Canceled, err) + + // Try to acquire the semaphore + assert.Equal(t, false, sem.TryAcquire(5)) +} + +func TestWeightedWithContextCanceledWhileWaiting(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + + sem := NewWeighted(ctx, 10) + + writers := 30 + wg := sync.WaitGroup{} + wg.Add(writers) + + assert.Equal(t, 10, sem.Size()) + assert.Equal(t, 0, sem.Acquired()) + assert.Equal(t, 10, sem.Available()) + assert.Equal(t, 0, sem.Waiting()) + + // Acquire the semaphore more than the semaphore size + for i := 0; i < writers; i++ { + go func() { + defer wg.Done() + sem.Acquire(1) + }() + } + + // Wait until 10 goroutines are blocked + for sem.Acquired() < 10 { + time.Sleep(1 * time.Millisecond) + } + + assert.Equal(t, 10, sem.Acquired()) + assert.Equal(t, 0, sem.Available()) + + // Release 10 goroutines + err := sem.Release(10) + assert.Equal(t, nil, err) + + // Wait until 10 goroutines are blocked + for sem.Acquired() < 10 { + time.Sleep(1 * time.Millisecond) + } + + // Cancel the context + cancel() + + // Wait for goroutines to finish + wg.Wait() + + assert.Equal(t, 10, sem.Acquired()) + assert.Equal(t, 0, sem.Available()) + assert.Equal(t, 0, sem.Waiting()) + assert.Equal(t, context.Canceled, sem.Acquire(1)) + assert.Equal(t, false, sem.TryAcquire(1)) +} diff --git a/pool.go b/pool.go index b93dade..038369c 100644 --- a/pool.go +++ b/pool.go @@ -10,6 +10,7 @@ import ( "github.com/alitto/pond/v2/internal/dispatcher" "github.com/alitto/pond/v2/internal/future" + "github.com/alitto/pond/v2/internal/semaphore" ) var NUM_CPU = runtime.NumCPU() @@ -17,6 +18,7 @@ var NUM_CPU = runtime.NumCPU() var MAX_TASKS_CHAN_LENGTH = NUM_CPU * 128 var ErrPoolStopped = errors.New("pool stopped") +var ErrQueueFull = errors.New("task queue is full") var poolStoppedFuture = func() Task { future, resolve := future.NewFuture(context.Background()) @@ -24,6 +26,12 @@ var poolStoppedFuture = func() Task { return future }() +var poolQueueFullFuture = func() Task { + future, resolve := future.NewFuture(context.Background()) + resolve(ErrQueueFull) + return future +}() + // basePool is the base interface for all pool types. type basePool interface { // Returns the number of worker goroutines that are currently active (executing a task) in the pool. @@ -47,6 +55,14 @@ type basePool interface { // Returns the maximum concurrency of the pool. MaxConcurrency() int + // Returns the size of the task queue. + QueueSize() int + + // Returns true if the pool is non-blocking, meaning that it will not block when the task queue is full. + // In a non-blocking pool, tasks that cannot be submitted to the queue will be dropped. + // By default, pools are blocking, meaning that they will block when the task queue is full. + NonBlocking() bool + // Returns the context associated with this pool. Context() context.Context @@ -73,8 +89,8 @@ type Pool interface { // Submits a task to the pool and returns a future that can be used to wait for the task to complete. SubmitErr(task func() error) Task - // Creates a new subpool with the specified maximum concurrency. - NewSubpool(maxConcurrency int) Pool + // Creates a new subpool with the specified maximum concurrency and options. + NewSubpool(maxConcurrency int, options ...Option) Pool // Creates a new task group. NewGroup() TaskGroup @@ -96,6 +112,9 @@ type pool struct { dispatcherRunning sync.Mutex successfulTaskCount atomic.Uint64 failedTaskCount atomic.Uint64 + nonBlocking bool + queueSize int + queueSem *semaphore.Weighted } func (p *pool) Context() context.Context { @@ -110,6 +129,14 @@ func (p *pool) MaxConcurrency() int { return p.maxConcurrency } +func (p *pool) QueueSize() int { + return p.queueSize +} + +func (p *pool) NonBlocking() bool { + return p.nonBlocking +} + func (p *pool) RunningWorkers() int64 { return p.workerCount.Load() } @@ -163,6 +190,19 @@ func (p *pool) submit(task any) Task { wrapped := wrapTask[struct{}, func(error)](task, resolve) + if p.queueSem != nil { + if p.nonBlocking { + if !p.queueSem.TryAcquire(1) { + return poolQueueFullFuture + } + } else { + if err := p.queueSem.Acquire(1); err != nil { + resolve(err) + return future + } + } + } + if err := p.dispatcher.Write(wrapped); err != nil { return poolStoppedFuture } @@ -186,8 +226,8 @@ func (p *pool) StopAndWait() { p.Stop().Wait() } -func (p *pool) NewSubpool(maxConcurrency int) Pool { - return newSubpool(maxConcurrency, p.ctx, p) +func (p *pool) NewSubpool(maxConcurrency int, options ...Option) Pool { + return newSubpool(maxConcurrency, p.Context(), p, options...) } func (p *pool) NewGroup() TaskGroup { @@ -297,6 +337,11 @@ func (p *pool) worker() { return } + // We have a task to execute, release the semaphore since it is no longer in the queue + if p.queueSem != nil { + p.queueSem.Release(1) + } + // Execute task _, err := invokeTask[any](task) @@ -343,6 +388,10 @@ func newPool(maxConcurrency int, options ...Option) *pool { pool.ctx, pool.cancel = context.WithCancelCause(pool.ctx) + if pool.queueSize > 0 { + pool.queueSem = semaphore.NewWeighted(pool.ctx, pool.queueSize) + } + pool.dispatcher = dispatcher.NewDispatcher(pool.ctx, pool.dispatch, tasksLen) return pool diff --git a/pool_test.go b/pool_test.go index 8f6b273..5618089 100644 --- a/pool_test.go +++ b/pool_test.go @@ -195,3 +195,66 @@ func TestPoolStoppedAfterCancel(t *testing.T) { assert.Equal(t, ErrPoolStopped, err) } + +func TestPoolWithQueueSize(t *testing.T) { + + pool := NewPool(1, WithQueueSize(10)) + + assert.Equal(t, 10, pool.QueueSize()) + assert.Equal(t, false, pool.NonBlocking()) + + var taskCount int = 50 + + for i := 0; i < taskCount; i++ { + pool.Submit(func() { + time.Sleep(1 * time.Millisecond) + }) + } + + pool.Stop().Wait() + + assert.Equal(t, uint64(taskCount), pool.SubmittedTasks()) + assert.Equal(t, uint64(taskCount), pool.CompletedTasks()) +} + +func TestPoolWithQueueSizeAndNonBlocking(t *testing.T) { + + pool := NewPool(10, WithQueueSize(10), WithNonBlocking(true)) + + assert.Equal(t, 10, pool.QueueSize()) + assert.Equal(t, true, pool.NonBlocking()) + + taskStarted := make(chan struct{}, 10) + taskWait := make(chan struct{}) + + for i := 0; i < 10; i++ { + pool.Submit(func() { + taskStarted <- struct{}{} + <-taskWait + }) + } + + // Wait for 10 tasks to start + for i := 0; i < 10; i++ { + <-taskStarted + } + + assert.Equal(t, int64(10), pool.RunningWorkers()) + assert.Equal(t, uint64(10), pool.SubmittedTasks()) + assert.Equal(t, uint64(0), pool.WaitingTasks()) + + // Saturate the queue + for i := 0; i < 10; i++ { + pool.Submit(func() { + time.Sleep(10 * time.Millisecond) + }) + } + + // Submit a task that should be rejected + task := pool.Submit(func() {}) + // Unblock tasks + close(taskWait) + assert.Equal(t, ErrQueueFull, task.Wait()) + + pool.Stop().Wait() +} diff --git a/pooloptions.go b/pooloptions.go index 25478d9..a6a0f0e 100644 --- a/pooloptions.go +++ b/pooloptions.go @@ -12,3 +12,17 @@ func WithContext(ctx context.Context) Option { p.ctx = ctx } } + +// WithQueueSize sets the max number of elements that can be queued in the pool. +func WithQueueSize(size int) Option { + return func(p *pool) { + p.queueSize = size + } +} + +// WithNonBlocking sets the pool to be non-blocking when the queue is full. +func WithNonBlocking(nonBlocking bool) Option { + return func(p *pool) { + p.nonBlocking = nonBlocking + } +} diff --git a/result.go b/result.go index 9d2a7da..87a19de 100644 --- a/result.go +++ b/result.go @@ -16,8 +16,8 @@ type ResultPool[R any] interface { // Submits a task to the pool and returns a future that can be used to wait for the task to complete and get the result. SubmitErr(task func() (R, error)) Result[R] - // Creates a new subpool with the specified maximum concurrency. - NewSubpool(maxConcurrency int) ResultPool[R] + // Creates a new subpool with the specified maximum concurrency and options. + NewSubpool(maxConcurrency int, options ...Option) ResultPool[R] // Creates a new task group. NewGroup() ResultTaskGroup[R] @@ -56,8 +56,8 @@ func (p *resultPool[R]) submit(task any) Result[R] { return future } -func (p *resultPool[R]) NewSubpool(maxConcurrency int) ResultPool[R] { - return newResultSubpool[R](maxConcurrency, p.Context(), p.pool) +func (p *resultPool[R]) NewSubpool(maxConcurrency int, options ...Option) ResultPool[R] { + return newResultSubpool[R](maxConcurrency, p.Context(), p.pool, options...) } func NewResultPool[R any](maxConcurrency int, options ...Option) ResultPool[R] { diff --git a/resultsubpool.go b/resultsubpool.go index 86ea28a..50b37c6 100644 --- a/resultsubpool.go +++ b/resultsubpool.go @@ -7,16 +7,17 @@ import ( "sync" "github.com/alitto/pond/v2/internal/dispatcher" + "github.com/alitto/pond/v2/internal/semaphore" ) type resultSubpool[R any] struct { *resultPool[R] parent *pool waitGroup sync.WaitGroup - sem chan struct{} + sem *semaphore.Weighted } -func newResultSubpool[R any](maxConcurrency int, ctx context.Context, parent *pool) ResultPool[R] { +func newResultSubpool[R any](maxConcurrency int, ctx context.Context, parent *pool, options ...Option) ResultPool[R] { if maxConcurrency == 0 { maxConcurrency = parent.MaxConcurrency() @@ -35,15 +36,27 @@ func newResultSubpool[R any](maxConcurrency int, ctx context.Context, parent *po tasksLen = MAX_TASKS_CHAN_LENGTH } - subpool := &resultSubpool[R]{ - resultPool: &resultPool[R]{ - pool: &pool{ - ctx: ctx, - maxConcurrency: maxConcurrency, - }, + resultPool := &resultPool[R]{ + pool: &pool{ + ctx: ctx, + maxConcurrency: maxConcurrency, }, - parent: parent, - sem: make(chan struct{}, maxConcurrency), + } + + for _, option := range options { + option(resultPool.pool) + } + + ctx = resultPool.Context() + + if resultPool.pool.queueSize > 0 { + resultPool.pool.queueSem = semaphore.NewWeighted(ctx, resultPool.pool.queueSize) + } + + subpool := &resultSubpool[R]{ + resultPool: resultPool, + parent: parent, + sem: semaphore.NewWeighted(ctx, maxConcurrency), } subpool.pool.dispatcher = dispatcher.NewDispatcher(ctx, subpool.dispatch, tasksLen) @@ -53,27 +66,36 @@ func newResultSubpool[R any](maxConcurrency int, ctx context.Context, parent *po func (p *resultSubpool[R]) dispatch(incomingTasks []any) { - p.waitGroup.Add(len(incomingTasks)) - // Submit tasks for _, task := range incomingTasks { - select { - case <-p.Context().Done(): - // Context canceled, exit - return - case p.sem <- struct{}{}: - // Acquired the semaphore, submit another task + // Acquire semaphore to limit concurrency + if p.nonBlocking { + if ok := p.sem.TryAcquire(1); !ok { + // Context canceled, exit + return + } + } else { + if err := p.sem.Acquire(1); err != nil { + // Context canceled, exit + return + } } subpoolTask := subpoolTask[any]{ task: task, sem: p.sem, + queueSem: p.queueSem, waitGroup: &p.waitGroup, updateMetrics: p.updateMetrics, } - p.parent.Go(subpoolTask.Run) + p.waitGroup.Add(1) + + if err := p.parent.Go(subpoolTask.Run); err != nil { + // We failed to submit the task, release semaphore + subpoolTask.Close() + } } } @@ -82,11 +104,13 @@ func (p *resultSubpool[R]) Stop() Task { p.dispatcher.CloseAndWait() p.waitGroup.Wait() - - close(p.sem) }) } func (p *resultSubpool[R]) StopAndWait() { p.Stop().Wait() } + +func (p *resultSubpool[R]) RunningWorkers() int64 { + return int64(p.sem.Acquired()) +} diff --git a/subpool.go b/subpool.go index baf1bbc..93479ff 100644 --- a/subpool.go +++ b/subpool.go @@ -7,6 +7,7 @@ import ( "sync" "github.com/alitto/pond/v2/internal/dispatcher" + "github.com/alitto/pond/v2/internal/semaphore" ) // subpool is a pool that is a subpool of another pool @@ -14,10 +15,10 @@ type subpool struct { *pool parent *pool waitGroup sync.WaitGroup - sem chan struct{} + sem *semaphore.Weighted } -func newSubpool(maxConcurrency int, ctx context.Context, parent *pool) Pool { +func newSubpool(maxConcurrency int, ctx context.Context, parent *pool, options ...Option) Pool { if maxConcurrency == 0 { maxConcurrency = parent.MaxConcurrency() @@ -36,13 +37,25 @@ func newSubpool(maxConcurrency int, ctx context.Context, parent *pool) Pool { tasksLen = MAX_TASKS_CHAN_LENGTH } + pool := &pool{ + ctx: ctx, + maxConcurrency: maxConcurrency, + } + + for _, option := range options { + option(pool) + } + + ctx = pool.Context() + + if pool.queueSize > 0 { + pool.queueSem = semaphore.NewWeighted(ctx, pool.queueSize) + } + subpool := &subpool{ - pool: &pool{ - ctx: ctx, - maxConcurrency: maxConcurrency, - }, + pool: pool, parent: parent, - sem: make(chan struct{}, maxConcurrency), + sem: semaphore.NewWeighted(ctx, maxConcurrency), } subpool.pool.dispatcher = dispatcher.NewDispatcher(ctx, subpool.dispatch, tasksLen) @@ -51,28 +64,34 @@ func newSubpool(maxConcurrency int, ctx context.Context, parent *pool) Pool { } func (p *subpool) dispatch(incomingTasks []any) { - - p.waitGroup.Add(len(incomingTasks)) - // Submit tasks for _, task := range incomingTasks { - select { - case <-p.Context().Done(): - // Context canceled, exit - return - case p.sem <- struct{}{}: - // Acquired the semaphore, submit another task + // Acquire semaphore to limit concurrency + if p.nonBlocking { + if ok := p.sem.TryAcquire(1); !ok { + return + } + } else { + if err := p.sem.Acquire(1); err != nil { + return + } } subpoolTask := subpoolTask[any]{ task: task, + queueSem: p.queueSem, sem: p.sem, waitGroup: &p.waitGroup, updateMetrics: p.updateMetrics, } - p.parent.Go(subpoolTask.Run) + p.waitGroup.Add(1) + + if err := p.parent.Go(subpoolTask.Run); err != nil { + // We failed to submit the task, release semaphore + subpoolTask.Close() + } } } @@ -81,11 +100,13 @@ func (p *subpool) Stop() Task { p.dispatcher.CloseAndWait() p.waitGroup.Wait() - - close(p.sem) }) } func (p *subpool) StopAndWait() { p.Stop().Wait() } + +func (p *subpool) RunningWorkers() int64 { + return int64(p.sem.Acquired()) +} diff --git a/subpool_test.go b/subpool_test.go index 29c5d4e..58ed2ea 100644 --- a/subpool_test.go +++ b/subpool_test.go @@ -191,3 +191,111 @@ func TestSubpoolStoppedAfterCancel(t *testing.T) { assert.Equal(t, ErrPoolStopped, err) } + +func TestSubpoolWithDifferentLimits(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + pool := NewPool(7, WithContext(ctx)) + + subpool1 := pool.NewSubpool(1) + subpool2 := pool.NewSubpool(2) + subpool3 := pool.NewSubpool(3) + + taskStarted := make(chan struct{}, 10) + taskWait := make(chan struct{}) + + var task = func() func() { + return func() { + taskStarted <- struct{}{} + <-taskWait + } + } + + // Submit tasks to subpool1 and wait for 1 task to start + for i := 0; i < 10; i++ { + subpool1.Submit(task()) + } + <-taskStarted + + // Submit tasks to subpool2 and wait for 2 tasks to start + for i := 0; i < 10; i++ { + subpool2.Submit(task()) + } + <-taskStarted + <-taskStarted + + // Submit tasks to subpool3 and wait for 3 tasks to start + for i := 0; i < 10; i++ { + subpool3.Submit(task()) + } + <-taskStarted + <-taskStarted + <-taskStarted + + // Submit tasks to the main pool and wait for 1 to start + for i := 0; i < 10; i++ { + pool.Submit(task()) + } + <-taskStarted + + // Verify concurrency of each pool + assert.Equal(t, int64(1), subpool1.RunningWorkers()) + assert.Equal(t, int64(2), subpool2.RunningWorkers()) + assert.Equal(t, int64(3), subpool3.RunningWorkers()) + assert.Equal(t, int64(7), pool.RunningWorkers()) + + assert.Equal(t, uint64(0), subpool1.CompletedTasks()) + assert.Equal(t, uint64(0), subpool2.CompletedTasks()) + assert.Equal(t, uint64(0), subpool3.CompletedTasks()) + assert.Equal(t, uint64(0), pool.CompletedTasks()) + + // Cancel the context to abort pending tasks + cancel() + + // Unblock all running tasks + close(taskWait) + + subpool1.StopAndWait() + subpool2.StopAndWait() + subpool3.StopAndWait() + pool.StopAndWait() + + assert.Equal(t, uint64(1), subpool1.CompletedTasks()) + assert.Equal(t, uint64(2), subpool2.CompletedTasks()) + assert.Equal(t, uint64(3), subpool3.CompletedTasks()) + assert.Equal(t, uint64(7), pool.CompletedTasks()) +} + +func TestSubpoolWithQueueSizeOverride(t *testing.T) { + pool := NewPool(10, WithQueueSize(10)) + + subpool := pool.NewSubpool(1, WithQueueSize(2), WithNonBlocking(true)) + + taskStarted := make(chan struct{}, 10) + taskWait := make(chan struct{}) + + var task = func() func() { + return func() { + taskStarted <- struct{}{} + <-taskWait + } + } + + // Submit tasks to subpool and wait for it to start + subpool.Submit(task()) + <-taskStarted + + // Submit more tasks to fill up the queue + for i := 0; i < 10; i++ { + subpool.Submit(task()) + } + + // 7 tasks should have been discarded + assert.Equal(t, int64(1), subpool.RunningWorkers()) + assert.Equal(t, uint64(3), subpool.SubmittedTasks()) + + // Unblock all running tasks + close(taskWait) + + subpool.StopAndWait() + pool.StopAndWait() +} diff --git a/task.go b/task.go index 6814073..518461b 100644 --- a/task.go +++ b/task.go @@ -4,24 +4,27 @@ import ( "errors" "fmt" "sync" + + "github.com/alitto/pond/v2/internal/semaphore" ) var ErrPanic = errors.New("task panicked") type subpoolTask[R any] struct { task any - sem chan struct{} + queueSem *semaphore.Weighted + sem *semaphore.Weighted waitGroup *sync.WaitGroup updateMetrics func(error) } func (t subpoolTask[R]) Run() { - defer func() { - // Release semaphore - <-t.sem - // Decrement wait group - t.waitGroup.Done() - }() + defer t.Close() + + // Release task queue semaphore when task is pulled from queue + if t.queueSem != nil { + t.queueSem.Release(1) + } _, err := invokeTask[R](t.task) @@ -30,6 +33,14 @@ func (t subpoolTask[R]) Run() { } } +func (t subpoolTask[R]) Close() { + // Release semaphore + t.sem.Release(1) + + // Decrement wait group + t.waitGroup.Done() +} + type wrappedTask[R any, C func(error) | func(R, error)] struct { task any callback C