@@ -6,7 +6,10 @@ package handlers
6
6
7
7
import (
8
8
"context"
9
+ "errors"
10
+ "sync/atomic"
9
11
"testing"
12
+ "time"
10
13
11
14
"github.com/stretchr/testify/require"
12
15
@@ -25,8 +28,15 @@ import (
25
28
)
26
29
27
30
type mockUpgradeManager struct {
28
- msgChan chan string
29
- completedChan chan struct {}
31
+ UpgradeFn func (
32
+ ctx context.Context ,
33
+ version string ,
34
+ sourceURI string ,
35
+ action * fleetapi.ActionUpgrade ,
36
+ details * details.Details ,
37
+ skipVerifyOverride bool ,
38
+ skipDefaultPgp bool ,
39
+ pgpBytes ... string ) (reexec.ShutdownCallbackFn , error )
30
40
}
31
41
32
42
func (u * mockUpgradeManager ) Upgradeable () bool {
@@ -37,15 +47,25 @@ func (u *mockUpgradeManager) Reload(rawConfig *config.Config) error {
37
47
return nil
38
48
}
39
49
40
- func (u * mockUpgradeManager ) Upgrade (ctx context.Context , version string , sourceURI string , action * fleetapi.ActionUpgrade , details * details.Details , skipVerifyOverride bool , skipDefaultPgp bool , pgpBytes ... string ) (_ reexec.ShutdownCallbackFn , err error ) {
41
- select {
42
- case <- u .completedChan :
43
- u .msgChan <- "completed " + version
44
- return nil , nil
45
- case <- ctx .Done ():
46
- u .msgChan <- "canceled " + version
47
- return nil , ctx .Err ()
48
- }
50
+ func (u * mockUpgradeManager ) Upgrade (
51
+ ctx context.Context ,
52
+ version string ,
53
+ sourceURI string ,
54
+ action * fleetapi.ActionUpgrade ,
55
+ details * details.Details ,
56
+ skipVerifyOverride bool ,
57
+ skipDefaultPgp bool ,
58
+ pgpBytes ... string ) (reexec.ShutdownCallbackFn , error ) {
59
+
60
+ return u .UpgradeFn (
61
+ ctx ,
62
+ version ,
63
+ sourceURI ,
64
+ action ,
65
+ details ,
66
+ skipVerifyOverride ,
67
+ skipDefaultPgp ,
68
+ pgpBytes ... )
49
69
}
50
70
51
71
func (u * mockUpgradeManager ) Ack (ctx context.Context , acker acker.Acker ) error {
@@ -65,8 +85,7 @@ func TestUpgradeHandler(t *testing.T) {
65
85
log , _ := logger .New ("" , false )
66
86
67
87
agentInfo := & info.AgentInfo {}
68
- msgChan := make (chan string )
69
- completedChan := make (chan struct {})
88
+ upgradeCalledChan := make (chan struct {})
70
89
71
90
// Create and start the coordinator
72
91
c := coordinator .New (
@@ -76,7 +95,21 @@ func TestUpgradeHandler(t *testing.T) {
76
95
agentInfo ,
77
96
component.RuntimeSpecs {},
78
97
nil ,
79
- & mockUpgradeManager {msgChan : msgChan , completedChan : completedChan },
98
+ & mockUpgradeManager {
99
+ UpgradeFn : func (
100
+ ctx context.Context ,
101
+ version string ,
102
+ sourceURI string ,
103
+ action * fleetapi.ActionUpgrade ,
104
+ details * details.Details ,
105
+ skipVerifyOverride bool ,
106
+ skipDefaultPgp bool ,
107
+ pgpBytes ... string ) (reexec.ShutdownCallbackFn , error ) {
108
+
109
+ upgradeCalledChan <- struct {}{}
110
+ return nil , nil
111
+ },
112
+ },
80
113
nil , nil , nil , nil , nil , false )
81
114
//nolint:errcheck // We don't need the termination state of the Coordinator
82
115
go c .Run (ctx )
@@ -86,11 +119,14 @@ func TestUpgradeHandler(t *testing.T) {
86
119
Version : "8.3.0" , SourceURI : "http://localhost" }}
87
120
ack := noopacker .New ()
88
121
err := u .Handle (ctx , & a , ack )
89
- // indicate that upgrade is completed
90
- close (completedChan )
91
122
require .NoError (t , err )
92
- msg := <- msgChan
93
- require .Equal (t , "completed 8.3.0" , msg )
123
+
124
+ // Make sure this test does not dead lock or wait for too long
125
+ select {
126
+ case <- time .Tick (50 * time .Millisecond ):
127
+ t .Fatal ("mockUpgradeManager.Upgrade was not called" )
128
+ case <- upgradeCalledChan :
129
+ }
94
130
}
95
131
96
132
func TestUpgradeHandlerSameVersion (t * testing.T ) {
@@ -102,18 +138,37 @@ func TestUpgradeHandlerSameVersion(t *testing.T) {
102
138
log , _ := logger .New ("" , false )
103
139
104
140
agentInfo := & info.AgentInfo {}
105
- msgChan := make (chan string )
106
- completedChan := make (chan struct {})
141
+ upgradeCalledChan := make (chan struct {})
107
142
108
143
// Create and start the Coordinator
144
+ upgradeCalled := atomic.Bool {}
109
145
c := coordinator .New (
110
146
log ,
111
147
configuration .DefaultConfiguration (),
112
148
logger .DefaultLogLevel ,
113
149
agentInfo ,
114
150
component.RuntimeSpecs {},
115
151
nil ,
116
- & mockUpgradeManager {msgChan : msgChan , completedChan : completedChan },
152
+ & mockUpgradeManager {
153
+ UpgradeFn : func (
154
+ ctx context.Context ,
155
+ version string ,
156
+ sourceURI string ,
157
+ action * fleetapi.ActionUpgrade ,
158
+ details * details.Details ,
159
+ skipVerifyOverride bool ,
160
+ skipDefaultPgp bool ,
161
+ pgpBytes ... string ) (reexec.ShutdownCallbackFn , error ) {
162
+
163
+ if upgradeCalled .CompareAndSwap (false , true ) {
164
+ upgradeCalledChan <- struct {}{}
165
+ return nil , nil
166
+ }
167
+ err := errors .New ("mockUpgradeManager.Upgrade called more than once" )
168
+ t .Error (err .Error ())
169
+ return nil , err
170
+ },
171
+ },
117
172
nil , nil , nil , nil , nil , false )
118
173
//nolint:errcheck // We don't need the termination state of the Coordinator
119
174
go c .Run (ctx )
@@ -126,10 +181,13 @@ func TestUpgradeHandlerSameVersion(t *testing.T) {
126
181
err2 := u .Handle (ctx , & a , ack )
127
182
require .NoError (t , err1 )
128
183
require .NoError (t , err2 )
129
- // indicate that upgrade is completed
130
- close (completedChan )
131
- msg := <- msgChan
132
- require .Equal (t , "completed 8.3.0" , msg )
184
+
185
+ // Make sure this test does not dead lock or wait for too long
186
+ select {
187
+ case <- time .Tick (50 * time .Millisecond ):
188
+ t .Fatal ("mockUpgradeManager.Upgrade was not called" )
189
+ case <- upgradeCalledChan :
190
+ }
133
191
}
134
192
135
193
func TestUpgradeHandlerNewVersion (t * testing.T ) {
@@ -139,10 +197,9 @@ func TestUpgradeHandlerNewVersion(t *testing.T) {
139
197
defer cancel ()
140
198
141
199
log , _ := logger .New ("" , false )
200
+ upgradeCalledChan := make (chan string )
142
201
143
202
agentInfo := & info.AgentInfo {}
144
- msgChan := make (chan string )
145
- completedChan := make (chan struct {})
146
203
147
204
// Create and start the Coordinator
148
205
c := coordinator .New (
@@ -152,7 +209,27 @@ func TestUpgradeHandlerNewVersion(t *testing.T) {
152
209
agentInfo ,
153
210
component.RuntimeSpecs {},
154
211
nil ,
155
- & mockUpgradeManager {msgChan : msgChan , completedChan : completedChan },
212
+ & mockUpgradeManager {
213
+ UpgradeFn : func (
214
+ ctx context.Context ,
215
+ version string ,
216
+ sourceURI string ,
217
+ action * fleetapi.ActionUpgrade ,
218
+ details * details.Details ,
219
+ skipVerifyOverride bool ,
220
+ skipDefaultPgp bool ,
221
+ pgpBytes ... string ) (reexec.ShutdownCallbackFn , error ) {
222
+
223
+ defer func () {
224
+ upgradeCalledChan <- version
225
+ }()
226
+ if version == "8.2.0" {
227
+ return nil , errors .New ("upgrade to 8.2.0 will always fail" )
228
+ }
229
+
230
+ return nil , nil
231
+ },
232
+ },
156
233
nil , nil , nil , nil , nil , false )
157
234
//nolint:errcheck // We don't need the termination state of the Coordinator
158
235
go c .Run (ctx )
@@ -163,14 +240,25 @@ func TestUpgradeHandlerNewVersion(t *testing.T) {
163
240
a2 := fleetapi.ActionUpgrade {Data : fleetapi.ActionUpgradeData {
164
241
Version : "8.5.0" , SourceURI : "http://localhost" }}
165
242
ack := noopacker .New ()
243
+
244
+ checkMsg := func (c <- chan string , expected , errMsg string ) {
245
+ t .Helper ()
246
+ // Make sure this test does not dead lock or wait for too long
247
+ // For some reason < 1s sometimes makes the test fail.
248
+ select {
249
+ case <- time .Tick (1300 * time .Millisecond ):
250
+ t .Fatal ("timed out waiting for Upgrade to return" )
251
+ case msg := <- c :
252
+ require .Equal (t , expected , msg , errMsg )
253
+ }
254
+ }
255
+
256
+ // Send both upgrade actions, a1 will error before a2 succeeds
166
257
err1 := u .Handle (ctx , & a1 , ack )
167
258
require .NoError (t , err1 )
259
+ checkMsg (upgradeCalledChan , "8.2.0" , "first call must be with version 8.2.0" )
260
+
168
261
err2 := u .Handle (ctx , & a2 , ack )
169
262
require .NoError (t , err2 )
170
- msg1 := <- msgChan
171
- require .Equal (t , "canceled 8.2.0" , msg1 )
172
- // indicate that upgrade is completed
173
- close (completedChan )
174
- msg2 := <- msgChan
175
- require .Equal (t , "completed 8.5.0" , msg2 )
263
+ checkMsg (upgradeCalledChan , "8.5.0" , "second call to Upgrade must be with version 8.5.0" )
176
264
}
0 commit comments