@@ -7,7 +7,9 @@ package handlers
7
7
import (
8
8
"context"
9
9
"errors"
10
+ "sync/atomic"
10
11
"testing"
12
+ "time"
11
13
12
14
"github.com/stretchr/testify/require"
13
15
@@ -27,9 +29,15 @@ import (
27
29
)
28
30
29
31
type mockUpgradeManager struct {
30
- msgChan chan string
31
- successChan chan string
32
- errChan chan error
32
+ UpgradeFn func (
33
+ ctx context.Context ,
34
+ version string ,
35
+ sourceURI string ,
36
+ action * fleetapi.ActionUpgrade ,
37
+ details * details.Details ,
38
+ skipVerifyOverride bool ,
39
+ skipDefaultPgp bool ,
40
+ pgpBytes ... string ) (reexec.ShutdownCallbackFn , error )
33
41
}
34
42
35
43
func (u * mockUpgradeManager ) Upgradeable () bool {
@@ -40,17 +48,25 @@ func (u *mockUpgradeManager) Reload(rawConfig *config.Config) error {
40
48
return nil
41
49
}
42
50
43
- func (u * mockUpgradeManager ) Upgrade (ctx context.Context , version string , sourceURI string , action * fleetapi.ActionUpgrade , details * details.Details , skipVerifyOverride bool , skipDefaultPgp bool , pgpBytes ... string ) (_ reexec.ShutdownCallbackFn , err error ) {
44
- select {
45
- case msg := <- u .successChan :
46
- u .msgChan <- msg
47
- return nil , nil
48
- case err := <- u .errChan :
49
- u .msgChan <- err .Error ()
50
- return nil , ctx .Err ()
51
- case <- ctx .Done ():
52
- return nil , ctx .Err ()
53
- }
51
+ func (u * mockUpgradeManager ) Upgrade (
52
+ ctx context.Context ,
53
+ version string ,
54
+ sourceURI string ,
55
+ action * fleetapi.ActionUpgrade ,
56
+ details * details.Details ,
57
+ skipVerifyOverride bool ,
58
+ skipDefaultPgp bool ,
59
+ pgpBytes ... string ) (reexec.ShutdownCallbackFn , error ) {
60
+
61
+ return u .UpgradeFn (
62
+ ctx ,
63
+ version ,
64
+ sourceURI ,
65
+ action ,
66
+ details ,
67
+ skipVerifyOverride ,
68
+ skipDefaultPgp ,
69
+ pgpBytes ... )
54
70
}
55
71
56
72
func (u * mockUpgradeManager ) Ack (ctx context.Context , acker acker.Acker ) error {
@@ -70,8 +86,7 @@ func TestUpgradeHandler(t *testing.T) {
70
86
log , _ := logger .New ("" , false )
71
87
72
88
agentInfo := & info.AgentInfo {}
73
- msgChan := make (chan string )
74
- completedChan := make (chan string )
89
+ upgradeCalledChan := make (chan struct {})
75
90
76
91
// Create and start the coordinator
77
92
c := coordinator .New (
@@ -81,7 +96,21 @@ func TestUpgradeHandler(t *testing.T) {
81
96
agentInfo ,
82
97
component.RuntimeSpecs {},
83
98
nil ,
84
- & mockUpgradeManager {msgChan : msgChan , successChan : completedChan },
99
+ & mockUpgradeManager {
100
+ UpgradeFn : func (
101
+ ctx context.Context ,
102
+ version string ,
103
+ sourceURI string ,
104
+ action * fleetapi.ActionUpgrade ,
105
+ details * details.Details ,
106
+ skipVerifyOverride bool ,
107
+ skipDefaultPgp bool ,
108
+ pgpBytes ... string ) (reexec.ShutdownCallbackFn , error ) {
109
+
110
+ upgradeCalledChan <- struct {}{}
111
+ return nil , nil
112
+ },
113
+ },
85
114
nil , nil , nil , nil , nil , false )
86
115
//nolint:errcheck // We don't need the termination state of the Coordinator
87
116
go c .Run (ctx )
@@ -91,11 +120,14 @@ func TestUpgradeHandler(t *testing.T) {
91
120
Version : "8.3.0" , SourceURI : "http://localhost" }}
92
121
ack := noopacker .New ()
93
122
err := u .Handle (ctx , & a , ack )
94
- // indicate that upgrade is completed
95
- close (completedChan )
96
123
require .NoError (t , err )
97
- msg := <- msgChan
98
- require .Equal (t , "completed 8.3.0" , msg )
124
+
125
+ // Make sure this test does not dead lock or wait for too long
126
+ select {
127
+ case <- time .Tick (50 * time .Millisecond ):
128
+ t .Fatal ("mockUpgradeManager.Upgrade was not called" )
129
+ case <- upgradeCalledChan :
130
+ }
99
131
}
100
132
101
133
func TestUpgradeHandlerSameVersion (t * testing.T ) {
@@ -109,19 +141,37 @@ func TestUpgradeHandlerSameVersion(t *testing.T) {
109
141
log , _ := logger .New ("" , false )
110
142
111
143
agentInfo := & info.AgentInfo {}
112
- msgChan := make (chan string )
113
- successChan := make (chan string )
114
- errChan := make (chan error )
144
+ upgradeCalledChan := make (chan struct {})
115
145
116
146
// Create and start the Coordinator
147
+ upgradeCalled := atomic.Bool {}
117
148
c := coordinator .New (
118
149
log ,
119
150
configuration .DefaultConfiguration (),
120
151
logger .DefaultLogLevel ,
121
152
agentInfo ,
122
153
component.RuntimeSpecs {},
123
154
nil ,
124
- & mockUpgradeManager {msgChan : msgChan , successChan : successChan , errChan : errChan },
155
+ & mockUpgradeManager {
156
+ UpgradeFn : func (
157
+ ctx context.Context ,
158
+ version string ,
159
+ sourceURI string ,
160
+ action * fleetapi.ActionUpgrade ,
161
+ details * details.Details ,
162
+ skipVerifyOverride bool ,
163
+ skipDefaultPgp bool ,
164
+ pgpBytes ... string ) (reexec.ShutdownCallbackFn , error ) {
165
+
166
+ if upgradeCalled .CompareAndSwap (false , true ) {
167
+ upgradeCalledChan <- struct {}{}
168
+ return nil , nil
169
+ }
170
+ err := errors .New ("mockUpgradeManager.Upgrade called more than once" )
171
+ t .Error (err .Error ())
172
+ return nil , err
173
+ },
174
+ },
125
175
nil , nil , nil , nil , nil , false )
126
176
//nolint:errcheck // We don't need the termination state of the Coordinator
127
177
go c .Run (ctx )
@@ -135,11 +185,12 @@ func TestUpgradeHandlerSameVersion(t *testing.T) {
135
185
require .NoError (t , err1 )
136
186
require .NoError (t , err2 )
137
187
138
- successChan <- "completed 8.3.0"
139
- require .Equal (t , "completed 8.3.0" , <- msgChan )
140
- errChan <- errors .New ("duplicated update, not finishing it?" )
141
- require .Equal (t , "duplicated update, not finishing it?" , <- msgChan )
142
-
188
+ // Make sure this test does not dead lock or wait for too long
189
+ select {
190
+ case <- time .Tick (50 * time .Millisecond ):
191
+ t .Fatal ("mockUpgradeManager.Upgrade was not called" )
192
+ case <- upgradeCalledChan :
193
+ }
143
194
}
144
195
145
196
func TestUpgradeHandlerNewVersion (t * testing.T ) {
@@ -149,11 +200,9 @@ func TestUpgradeHandlerNewVersion(t *testing.T) {
149
200
defer cancel ()
150
201
151
202
log , _ := logger .New ("" , false )
203
+ upgradeCalledChan := make (chan string )
152
204
153
205
agentInfo := & info.AgentInfo {}
154
- msgChan := make (chan string )
155
- completedChan := make (chan string )
156
- errorChan := make (chan error )
157
206
158
207
// Create and start the Coordinator
159
208
c := coordinator .New (
@@ -163,7 +212,27 @@ func TestUpgradeHandlerNewVersion(t *testing.T) {
163
212
agentInfo ,
164
213
component.RuntimeSpecs {},
165
214
nil ,
166
- & mockUpgradeManager {msgChan : msgChan , successChan : completedChan , errChan : errorChan },
215
+ & mockUpgradeManager {
216
+ UpgradeFn : func (
217
+ ctx context.Context ,
218
+ version string ,
219
+ sourceURI string ,
220
+ action * fleetapi.ActionUpgrade ,
221
+ details * details.Details ,
222
+ skipVerifyOverride bool ,
223
+ skipDefaultPgp bool ,
224
+ pgpBytes ... string ) (reexec.ShutdownCallbackFn , error ) {
225
+
226
+ defer func () {
227
+ upgradeCalledChan <- version
228
+ }()
229
+ if version == "8.2.0" {
230
+ return nil , errors .New ("upgrade to 8.2.0 will always fail" )
231
+ }
232
+
233
+ return nil , nil
234
+ },
235
+ },
167
236
nil , nil , nil , nil , nil , false )
168
237
//nolint:errcheck // We don't need the termination state of the Coordinator
169
238
go c .Run (ctx )
@@ -175,18 +244,24 @@ func TestUpgradeHandlerNewVersion(t *testing.T) {
175
244
Version : "8.5.0" , SourceURI : "http://localhost" }}
176
245
ack := noopacker .New ()
177
246
247
+ checkMsg := func (c <- chan string , expected , errMsg string ) {
248
+ t .Helper ()
249
+ // Make sure this test does not dead lock or wait for too long
250
+ // For some reason < 1s sometimes makes the test fail.
251
+ select {
252
+ case <- time .Tick (1300 * time .Millisecond ):
253
+ t .Fatal ("timed out waiting for Upgrade to return" )
254
+ case msg := <- c :
255
+ require .Equal (t , expected , msg , errMsg )
256
+ }
257
+ }
258
+
178
259
// Send both upgrade actions, a1 will error before a2 succeeds
179
260
err1 := u .Handle (ctx , & a1 , ack )
180
261
require .NoError (t , err1 )
262
+ checkMsg (upgradeCalledChan , "8.2.0" , "first call must be with version 8.2.0" )
263
+
181
264
err2 := u .Handle (ctx , & a2 , ack )
182
265
require .NoError (t , err2 )
183
-
184
- // Send an error so the first action is "cancelled"
185
- errorChan <- errors .New ("cancelled 8.2.0" )
186
- // Wait for the mockUpgradeHandler to receive and "process" the error
187
- require .Equal (t , "cancelled 8.2.0" , <- msgChan , "mockUpgradeHandler.Upgrade did not receive the expected error" )
188
-
189
- // Send a success so the second action succeeds
190
- completedChan <- "completed 8.5.0"
191
- require .Equal (t , "completed 8.5.0" , <- msgChan , "mockUpgradeHandler.Upgrade did not receive the success signal" )
266
+ checkMsg (upgradeCalledChan , "8.5.0" , "second call to Upgrade must be with version 8.5.0" )
192
267
}
0 commit comments