Skip to content

Commit

Permalink
fix(taskgroups): do not start remaining tasks after context cancellat…
Browse files Browse the repository at this point in the history
…ion (#77)
  • Loading branch information
alitto authored Oct 24, 2024
1 parent ad81b48 commit 1e28250
Show file tree
Hide file tree
Showing 4 changed files with 121 additions and 34 deletions.
87 changes: 53 additions & 34 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -94,12 +94,6 @@ task := pool.SubmitErr(func() error {

// Wait for the task to complete and get the error
err := task.Wait()

if err != nil {
fmt.Printf("Failed to run task: %v", err)
} else {
fmt.Println("Task completed successfully")
}
```

### Submitting tasks that return results
Expand All @@ -117,12 +111,7 @@ task := pool.Submit(func() (string) {

// Wait for the task to complete and get the result
result, err := task.Wait()

if err != nil {
fmt.Printf("Failed to run task: %v", err)
} else {
fmt.Printf("Task result: %v", result)
}
// result = "Hello, World!" and err = nil
```

### Submitting tasks that return results or errors
Expand All @@ -140,12 +129,28 @@ task := pool.SubmitErr(func() (string, error) {

// Wait for the task to complete and get the result
result, err := task.Wait()
// result = "Hello, World!" and err = nil
```

if err != nil {
fmt.Printf("Failed to run task: %v", err)
} else {
fmt.Printf("Task result: %v", result)
}
### Submitting tasks associated with a context

If you need to submit a task that is associated with a context, you can pass the context directly to the task function.

``` go
// Create a pool with limited concurrency
pool := pond.NewPool(10)

// Create a context that can be cancelled
ctx, cancel := context.WithCancel(context.Background())

// Submit a task that is associated with a context
task := pool.SubmitErr(func() error {
return doSomethingWithCtx(ctx) // Pass the context to the task directly
})

// Wait for the task to complete and get the error.
// If the context is cancelled, the task is stopped and an error is returned.
err := task.Wait()
```

### Submitting a group of related tasks
Expand All @@ -169,11 +174,6 @@ for i := 0; i < 20; i++ {

// Wait for all tasks in the group to complete
err := group.Wait()
if err != nil {
fmt.Printf("Failed to complete group tasks: %v", err)
} else {
fmt.Println("Successfully completed all group tasks")
}
```

### Submitting a group of related tasks and waiting for the first error
Expand Down Expand Up @@ -201,12 +201,6 @@ for i := 0; i < 20; i++ {

// Wait for all tasks in the group to complete or the first error to occur
err := group.Wait()

if err != nil {
fmt.Printf("Failed to complete group tasks: %v", err)
} else {
fmt.Println("Successfully completed all group tasks")
}
```

### Submitting a group of related tasks that return results
Expand All @@ -230,19 +224,44 @@ for i := 0; i < 20; i++ {

// Wait for all tasks in the group to complete
results, err := group.Wait()
// results = ["Task #0", "Task #1", ..., "Task #19"]
// results = ["Task #0", "Task #1", ..., "Task #19"] and err = nil
```

if err != nil {
fmt.Printf("Failed to complete group tasks: %v", err)
} else {
fmt.Printf("Successfully completed all group tasks: %v", results)
### Submitting a group of tasks associated with a context

If you need to submit a group of tasks that are associated with a context, you can pass the context directly to the task function.
Just make sure to handle the context in the task function to stop the task when the context is cancelled.

``` go
// Create a pool with limited concurrency
pool := pond.NewPool(10)

// Create a context that can be cancelled
ctx, cancel := context.WithCancel(context.Background())

// Create a task group
group := pool.NewGroup()

// Submit a group of tasks
for i := 0; i < 20; i++ {
i := i
group.SubmitErr(func() error {
return doSomethingWithCtx(ctx) // Pass the context to the task directly
})
}

// Wait for all tasks in the group to complete.
// If the context is cancelled, all tasks are stopped and the first error is returned.
err := group.Wait()
```

### Using a custom Context
### Using a custom Context at the pool level

Each pool is associated with a context that is used to stop all workers when the pool is stopped. By default, the context is the background context (`context.Background()`). You can create a custom context and pass it to the pool to stop all workers when the context is cancelled.

> [!NOTE]
> The context passed to a pool with `pond.WithContext` is meant to be used to stop the pool and not to stop individual tasks. If you need to stop individual tasks, you should pass the context directly to the task function and handle it accordingly. See [Submitting tasks associated with a context](#submitting-tasks-associated-with-a-context) and [Submitting a group of tasks associated with a context](#submitting-a-group-of-tasks-associated-with-a-context).
```go
// Create a custom context that can be cancelled
customCtx, cancel := context.WithCancel(context.Background())
Expand Down
9 changes: 9 additions & 0 deletions group.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,15 @@ func (g *abstractTaskGroup[T, E, O]) submit(task any) {
index := int(g.nextIndex.Add(1) - 1)

err := g.pool.Go(func() {
// Check if the context has been cancelled to prevent running tasks that are not needed
if err := g.future.Context().Err(); err != nil {
g.futureResolver(index, &result[O]{
Err: err,
}, err)
return
}

// Invoke the task
output, err := invokeTask[O](task)

g.futureResolver(index, &result[O]{
Expand Down
55 changes: 55 additions & 0 deletions group_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
package pond

import (
"context"
"errors"
"sync/atomic"
"testing"
"time"

Expand Down Expand Up @@ -116,6 +118,31 @@ func TestTaskGroupWithStoppedPool(t *testing.T) {
assert.Equal(t, ErrPoolStopped, err)
}

func TestTaskGroupWithContextCanceled(t *testing.T) {

pool := NewPool(100)

group := pool.NewGroup()

ctx, cancel := context.WithCancel(context.Background())

go func() {
time.Sleep(10 * time.Millisecond)
cancel()
}()

err := group.SubmitErr(func() error {
select {
case <-ctx.Done():
return ctx.Err()
case <-time.After(100 * time.Millisecond):
return nil
}
}).Wait()

assert.Equal(t, context.Canceled, err)
}

func TestTaskGroupWithNoTasks(t *testing.T) {

group := NewResultPool[int](10).
Expand All @@ -128,3 +155,31 @@ func TestTaskGroupWithNoTasks(t *testing.T) {
group.SubmitErr()
})
}

func TestTaskGroupCanceledShouldSkipRemainingTasks(t *testing.T) {

pool := NewPool(1)

group := pool.NewGroup()

var executedCount atomic.Int32
sampleErr := errors.New("sample error")

group.Submit(func() {
executedCount.Add(1)
})

group.SubmitErr(func() error {
time.Sleep(10 * time.Millisecond)
return sampleErr
})

group.Submit(func() {
executedCount.Add(1)
})

err := group.Wait()

assert.Equal(t, sampleErr, err)
assert.Equal(t, int32(1), executedCount.Load())
}
4 changes: 4 additions & 0 deletions internal/future/composite.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@ func NewCompositeFuture[V any](ctx context.Context) (*CompositeFuture[V], Compos
return future, future.resolve
}

func (f *CompositeFuture[V]) Context() context.Context {
return f.ctx
}

func (f *CompositeFuture[V]) Wait(count int) ([]V, error) {
f.mutex.Lock()
defer f.mutex.Unlock()
Expand Down

0 comments on commit 1e28250

Please sign in to comment.