Skip to content

Commit 8f9ea31

Browse files
committed
Add context for Watch.
1 parent 0dcfcd1 commit 8f9ea31

File tree

2 files changed

+5
-2
lines changed

2 files changed

+5
-2
lines changed

pkg/client/client.go

+3-1
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ func (c *Client) handleNatsMsg(msg *nats.Msg, callback NodeStateChangeCallback)
113113
return nil
114114
}
115115

116-
func (c *Client) Watch(service string, handleNodeState NodeStateChangeCallback) error {
116+
func (c *Client) Watch(ctx context.Context, service string, handleNodeState NodeStateChangeCallback) error {
117117
if handleNodeState == nil {
118118
err := fmt.Errorf("Watch callback must be set for %v", service)
119119
logger.Warnf("Watch: err => %v", err)
@@ -138,6 +138,8 @@ func (c *Client) Watch(service string, handleNodeState NodeStateChangeCallback)
138138
select {
139139
case <-c.ctx.Done():
140140
return c.ctx.Err()
141+
case <-ctx.Done():
142+
return ctx.Err()
141143
case msg, ok := <-msgCh:
142144
if ok {
143145
err := c.handleNatsMsg(msg, handleNodeState)

pkg/client/client_test.go

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package client
22

33
import (
4+
"context"
45
"sync"
56
"testing"
67

@@ -51,7 +52,7 @@ func TestWatch(t *testing.T) {
5152
ExtraInfo: extraInfo,
5253
}
5354

54-
s.Watch("sfu", func(state discovery.NodeState, n *discovery.Node) {
55+
s.Watch(context.Background(), "sfu", func(state discovery.NodeState, n *discovery.Node) {
5556
if state == discovery.NodeUp {
5657
log.Infof("NodeUp => %v", *n)
5758
assert.Equal(t, node, *n)

0 commit comments

Comments
 (0)