Skip to content

Commit

Permalink
feat(taskgroup): improve task group functionality (#81)
Browse files Browse the repository at this point in the history
  • Loading branch information
alitto authored Nov 9, 2024
1 parent 12d1782 commit 11ecf70
Show file tree
Hide file tree
Showing 12 changed files with 352 additions and 63 deletions.
31 changes: 27 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,32 @@ for i := 0; i < 20; i++ {
err := group.Wait()
```

### Submitting a group of related tasks associated with a context

You can submit a group of tasks that are linked to a context. This is useful when you need to execute a group of tasks concurrently and stop them when the context is cancelled (e.g. when the parent task is cancelled or times out).

``` go
// Create a pool with limited concurrency
pool := pond.NewPool(10)

// Create a context with a 5s timeout
timeout, _ := context.WithTimeout(context.Background(), 5*time.Second)

// Create a task group with a context
group := pool.NewGroupContext(timeout)

// Submit a group of tasks
for i := 0; i < 20; i++ {
i := i
group.Submit(func() {
fmt.Printf("Running group task #%d\n", i)
})
}

// Wait for all tasks in the group to complete or the timeout to occur, whichever comes first
err := group.Wait()
```

### Submitting a group of related tasks and waiting for the first error

You can submit a group of tasks that are related to each other and wait for the first error to occur. This is useful when you need to execute a group of tasks concurrently and stop the execution if an error occurs.
Expand Down Expand Up @@ -259,9 +285,6 @@ err := group.Wait()

Each pool is associated with a context that is used to stop all workers when the pool is stopped. By default, the context is the background context (`context.Background()`). You can create a custom context and pass it to the pool to stop all workers when the context is cancelled.

> [!NOTE]
> The context passed to a pool with `pond.WithContext` is meant to be used to stop the pool and not to stop individual tasks. If you need to stop individual tasks, you should pass the context directly to the task function and handle it accordingly. See [Submitting tasks associated with a context](#submitting-tasks-associated-with-a-context) and [Submitting a group of tasks associated with a context](#submitting-a-group-of-tasks-associated-with-a-context).
```go
// Create a custom context that can be cancelled
customCtx, cancel := context.WithCancel(context.Background())
Expand Down Expand Up @@ -391,7 +414,7 @@ If you are using pond v1, here are the changes you need to make to migrate to v2
- `pond.Strategy`: The pool now scales automatically based on the number of tasks submitted.
5. The `pool.StopAndWaitFor` method was deprecated. Use `pool.Stop().Done()` channel if you need to wait for the pool to stop in a select statement.
6. The `pool.Group` method was renamed to `pool.NewGroup`.
7. The `pool.GroupContext` method was deprecated. Use `pool.NewGroup` instead and pass the context directly in the inline task function.
7. The `pool.GroupContext` was renamed to `pool.NewGroupWithContext`.


## Examples
Expand Down
1 change: 0 additions & 1 deletion docs/strategies.svg

This file was deleted.

7 changes: 7 additions & 0 deletions examples/task_group_context/go.mod
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
module github.com/alitto/pond/v2/examples/task_group_context

go 1.22

require github.com/alitto/pond/v2 v2.0.0

replace github.com/alitto/pond/v2 => ../../
49 changes: 49 additions & 0 deletions examples/task_group_context/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
package main

import (
"context"
"fmt"
"time"

"github.com/alitto/pond/v2"
)

func main() {
// Generate 1000 tasks that each take 1 second to complete
tasks := generateTasks(1000, 1*time.Second)

// Create a pool with a max concurrency of 10
pool := pond.NewPool(10)
defer pool.StopAndWait()

// Create a context with a timeout of 5 seconds
timeout, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()

// Create a group with the timeout context
group := pool.NewGroupContext(timeout)

// Submit all tasks to the group and wait for them to complete or the timeout to expire
err := group.Submit(tasks...).Wait()

if err != nil {
fmt.Printf("Group completed with error: %v\n", err)
} else {
fmt.Println("Group completed successfully")
}
}

func generateTasks(count int, duration time.Duration) []func() {

tasks := make([]func(), count)

for i := 0; i < count; i++ {
i := i
tasks[i] = func() {
time.Sleep(duration)
fmt.Printf("Task #%d finished\n", i)
}
}

return tasks
}
32 changes: 28 additions & 4 deletions group.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
package pond

import (
"context"
"errors"
"sync/atomic"

"github.com/alitto/pond/v2/internal/future"
)

var ErrGroupStopped = errors.New("task group stopped")

// TaskGroup represents a group of tasks that can be executed concurrently.
// The group can be waited on to block until all tasks have completed.
// If any of the tasks return an error, the group will return the first error encountered.
Expand All @@ -19,6 +23,12 @@ type TaskGroup interface {

// Waits for all tasks in the group to complete.
Wait() error

// Returns a channel that is closed when all tasks in the group have completed, a task returns an error, or the group is stopped.
Done() <-chan struct{}

// Stops the group and cancels all remaining tasks. Running tasks are not interrupted.
Stop()
}

// ResultTaskGroup represents a group of tasks that can be executed concurrently.
Expand All @@ -35,6 +45,12 @@ type ResultTaskGroup[O any] interface {

// Waits for all tasks in the group to complete.
Wait() ([]O, error)

// Returns a channel that is closed when all tasks in the group have completed, a task returns an error, or the group is stopped.
Done() <-chan struct{}

// Stops the group and cancels all remaining tasks. Running tasks are not interrupted.
Stop()
}

type result[O any] struct {
Expand All @@ -49,6 +65,14 @@ type abstractTaskGroup[T func() | func() O, E func() error | func() (O, error),
futureResolver future.CompositeFutureResolver[*result[O]]
}

func (g *abstractTaskGroup[T, E, O]) Done() <-chan struct{} {
return g.future.Done(int(g.nextIndex.Load()))
}

func (g *abstractTaskGroup[T, E, O]) Stop() {
g.future.Cancel(ErrGroupStopped)
}

func (g *abstractTaskGroup[T, E, O]) Submit(tasks ...T) *abstractTaskGroup[T, E, O] {
for _, task := range tasks {
g.submit(task)
Expand Down Expand Up @@ -142,8 +166,8 @@ func (g *resultTaskGroup[O]) Wait() ([]O, error) {
return values, err
}

func newTaskGroup(pool *pool) TaskGroup {
future, futureResolver := future.NewCompositeFuture[*result[struct{}]](pool.Context())
func newTaskGroup(pool *pool, ctx context.Context) TaskGroup {
future, futureResolver := future.NewCompositeFuture[*result[struct{}]](ctx)

return &taskGroup{
abstractTaskGroup: abstractTaskGroup[func(), func() error, struct{}]{
Expand All @@ -154,8 +178,8 @@ func newTaskGroup(pool *pool) TaskGroup {
}
}

func newResultTaskGroup[O any](pool *pool) ResultTaskGroup[O] {
future, futureResolver := future.NewCompositeFuture[*result[O]](pool.Context())
func newResultTaskGroup[O any](pool *pool, ctx context.Context) ResultTaskGroup[O] {
future, futureResolver := future.NewCompositeFuture[*result[O]](ctx)

return &resultTaskGroup[O]{
abstractTaskGroup: abstractTaskGroup[func() O, func() (O, error), O]{
Expand Down
71 changes: 71 additions & 0 deletions group_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -186,3 +186,74 @@ func TestTaskGroupCanceledShouldSkipRemainingTasks(t *testing.T) {
assert.Equal(t, sampleErr, err)
assert.Equal(t, int32(1), executedCount.Load())
}

func TestTaskGroupWithCustomContext(t *testing.T) {
pool := NewPool(1)

ctx, cancel := context.WithCancel(context.Background())

group := pool.NewGroupContext(ctx)

var executedCount atomic.Int32

group.Submit(func() {
executedCount.Add(1)
})
group.Submit(func() {
executedCount.Add(1)
cancel()
})
group.Submit(func() {
executedCount.Add(1)
})

err := group.Wait()

assert.Equal(t, context.Canceled, err)
assert.Equal(t, struct{}{}, <-group.Done())
assert.Equal(t, int32(2), executedCount.Load())
}

func TestTaskGroupStop(t *testing.T) {
pool := NewPool(1)

group := pool.NewGroup()

var executedCount atomic.Int32

group.Submit(func() {
executedCount.Add(1)
})
group.Submit(func() {
executedCount.Add(1)
group.Stop()
})
group.Submit(func() {
executedCount.Add(1)
})

err := group.Wait()

assert.Equal(t, ErrGroupStopped, err)
assert.Equal(t, struct{}{}, <-group.Done())
assert.Equal(t, int32(2), executedCount.Load())
}

func TestTaskGroupDone(t *testing.T) {
pool := NewPool(10)

group := pool.NewGroup()

var executedCount atomic.Int32

for i := 0; i < 5; i++ {
group.Submit(func() {
time.Sleep(1 * time.Millisecond)
executedCount.Add(1)
})
}

<-group.Done()

assert.Equal(t, int32(5), executedCount.Load())
}
Loading

0 comments on commit 11ecf70

Please sign in to comment.