diff --git a/client/v3/concurrency/election.go b/client/v3/concurrency/election.go index ac1303dd8b4..bd520023018 100644 --- a/client/v3/concurrency/election.go +++ b/client/v3/concurrency/election.go @@ -90,7 +90,7 @@ func (e *Election) Campaign(ctx context.Context, val string) error { } } - err = waitDeletes(ctx, client, e.keyPrefix, e.leaderRev-1) + err = waitDeletes(ctx, client, e.keyPrefix, e.leaderKey, e.leaderRev-1) if err != nil { // clean up in case of context cancel select { diff --git a/client/v3/concurrency/key.go b/client/v3/concurrency/key.go index 92e365c4715..0912e938d62 100644 --- a/client/v3/concurrency/key.go +++ b/client/v3/concurrency/key.go @@ -22,31 +22,63 @@ import ( v3 "go.etcd.io/etcd/client/v3" ) -func waitDelete(ctx context.Context, client *v3.Client, key string, rev int64) error { +var ( + ErrLostWatcher = errors.New("lost watcher waiting for delete") + ErrSessionExpiredDuringWait = errors.New("session expired during wait") +) + +func waitDelete(ctx context.Context, client *v3.Client, key, sessionKey string, rev int64) error { cctx, cancel := context.WithCancel(ctx) defer cancel() - var wr v3.WatchResponse wch := client.Watch(cctx, key, v3.WithRev(rev)) - for wr = range wch { - for _, ev := range wr.Events { - if ev.Type == mvccpb.DELETE { - return nil + sch := client.Watch(cctx, sessionKey) + + for { + select { + case wr, ok := <-wch: + if !ok { + if err := wr.Err(); err != nil { + return err + } + return ErrLostWatcher + } + + if err := wr.Err(); err != nil { + return err + } + + for _, ev := range wr.Events { + if ev.Type == mvccpb.DELETE { + return nil + } } + case sr, ok := <-sch: + if !ok { + if err := sr.Err(); err != nil { + return err + } + return ErrLostWatcher + } + + if err := sr.Err(); err != nil { + return err + } + + for _, ev := range sr.Events { + if ev.Type == mvccpb.DELETE { + return ErrSessionExpiredDuringWait + } + } + case <-ctx.Done(): + return ctx.Err() } } - if err := wr.Err(); err != nil { - return err - } - if err := ctx.Err(); err != nil { - return err - } - return errors.New("lost watcher waiting for delete") } // waitDeletes efficiently waits until all keys matching the prefix and no greater // than the create revision are deleted. -func waitDeletes(ctx context.Context, client *v3.Client, pfx string, maxCreateRev int64) error { +func waitDeletes(ctx context.Context, client *v3.Client, pfx, sessionKey string, maxCreateRev int64) error { getOpts := append(v3.WithLastCreate(), v3.WithMaxCreateRev(maxCreateRev)) for { resp, err := client.Get(ctx, pfx, getOpts...) @@ -57,7 +89,7 @@ func waitDeletes(ctx context.Context, client *v3.Client, pfx string, maxCreateRe return nil } lastKey := string(resp.Kvs[0].Key) - if err = waitDelete(ctx, client, lastKey, resp.Header.Revision); err != nil { + if err = waitDelete(ctx, client, lastKey, sessionKey, resp.Header.Revision); err != nil { return err } } diff --git a/client/v3/concurrency/mutex.go b/client/v3/concurrency/mutex.go index 6898bbcec41..9ebe054c543 100644 --- a/client/v3/concurrency/mutex.go +++ b/client/v3/concurrency/mutex.go @@ -85,8 +85,7 @@ func (m *Mutex) Lock(ctx context.Context) error { } client := m.s.Client() // wait for deletion revisions prior to myKey - // TODO: early termination if the session key is deleted before other session keys with smaller revisions. - werr := waitDeletes(ctx, client, m.pfx, m.myRev-1) + werr := waitDeletes(ctx, client, m.pfx, m.myKey, m.myRev-1) // release lock key if wait failed if werr != nil { m.Unlock(client.Ctx())