@@ -65,9 +65,12 @@ class ExternalOtaProvider:
65
65
66
66
ENDPOINT_ID : Final [int ] = 0
67
67
68
- def __init__ (self , vendor_id : int , ota_provider_dir : Path ) -> None :
68
+ def __init__ (
69
+ self , vendor_id : int , ota_provider_base_dir : Path , ota_provider_dir : Path
70
+ ) -> None :
69
71
"""Initialize the OTA provider."""
70
72
self ._vendor_id : int = vendor_id
73
+ self ._ota_provider_base_dir : Path = ota_provider_base_dir
71
74
self ._ota_provider_dir : Path = ota_provider_dir
72
75
self ._ota_file_path : Path | None = None
73
76
self ._ota_provider_proc : Process | None = None
@@ -261,10 +264,11 @@ async def stop(self) -> None:
261
264
self ._ota_provider_proc = None
262
265
self ._ota_provider_task = None
263
266
264
- async def download_update (self , update_desc : dict ) -> None :
265
- """Download update file from OTA Path and add it to the OTA provider."""
267
+ async def _download_update (
268
+ self , url : str , checksum_alg : hashlib ._Hash | None
269
+ ) -> Path :
270
+ """Download update file from OTA URL."""
266
271
267
- url = update_desc ["otaUrl" ]
268
272
parsed_url = urlparse (url )
269
273
file_name = unquote (Path (parsed_url .path ).name )
270
274
@@ -273,20 +277,6 @@ async def download_update(self, update_desc: dict) -> None:
273
277
file_path = self ._ota_provider_dir / file_name
274
278
275
279
try :
276
- checksum_alg = None
277
- if (
278
- "otaChecksum" in update_desc
279
- and "otaChecksumType" in update_desc
280
- and update_desc ["otaChecksumType" ] in CHECHKSUM_TYPES
281
- ):
282
- checksum_alg = hashlib .new (
283
- CHECHKSUM_TYPES [update_desc ["otaChecksumType" ]]
284
- )
285
- else :
286
- LOGGER .warning (
287
- "No OTA checksum type or not supported, OTA will not be checked."
288
- )
289
-
290
280
async with ClientSession (raise_for_status = True ) as session :
291
281
# fetch the paa certificates list
292
282
LOGGER .debug ("Download update from '%s'." , url )
@@ -300,20 +290,6 @@ async def download_update(self, update_desc: dict) -> None:
300
290
if checksum_alg :
301
291
checksum_alg .update (chunk )
302
292
303
- # Download finished, check checksum if necessary
304
- if checksum_alg :
305
- checksum = b64encode (checksum_alg .digest ()).decode ("ascii" )
306
- checksum_expected = update_desc ["otaChecksum" ].strip ()
307
- if checksum != checksum_expected :
308
- LOGGER .error (
309
- "Checksum mismatch for file '%s', expected: '%s', got: '%s'" ,
310
- file_name ,
311
- checksum_expected ,
312
- checksum ,
313
- )
314
- await loop .run_in_executor (None , file_path .unlink )
315
- raise UpdateError ("Checksum mismatch!" )
316
-
317
293
LOGGER .info (
318
294
"Update file '%s' downloaded to '%s'" ,
319
295
file_name ,
@@ -326,6 +302,55 @@ async def download_update(self, update_desc: dict) -> None:
326
302
)
327
303
raise UpdateError ("Fetching software version failed" ) from err
328
304
305
+ return file_path
306
+
307
+ async def fetch_update (self , update_desc : dict ) -> None :
308
+ """Fetch update file from OTA URL."""
309
+ url = update_desc ["otaUrl" ]
310
+ parsed_url = urlparse (url )
311
+ file_name = unquote (Path (parsed_url .path ).name )
312
+
313
+ loop = asyncio .get_running_loop ()
314
+
315
+ checksum_alg = None
316
+ if (
317
+ "otaChecksum" in update_desc
318
+ and "otaChecksumType" in update_desc
319
+ and update_desc ["otaChecksumType" ] in CHECHKSUM_TYPES
320
+ ):
321
+ checksum_alg = hashlib .new (CHECHKSUM_TYPES [update_desc ["otaChecksumType" ]])
322
+ else :
323
+ LOGGER .warning (
324
+ "No OTA checksum type or not supported, OTA will not be checked."
325
+ )
326
+
327
+ if parsed_url .scheme in ["http" , "https" ]:
328
+ file_path = await self ._download_update (url , checksum_alg )
329
+ elif parsed_url .scheme in ["file" ]:
330
+ file_path = self ._ota_provider_base_dir / Path (parsed_url .path [1 :])
331
+ if not file_path .exists ():
332
+ logging .warning ("Local update file not found: %s" , file_path )
333
+ raise UpdateError ("Local update file not found" )
334
+ if checksum_alg :
335
+ checksum_alg .update (
336
+ await loop .run_in_executor (None , file_path .read_bytes )
337
+ )
338
+ else :
339
+ raise UpdateError ("Unsupported OTA URL scheme" )
340
+
341
+ # Download finished, check checksum if necessary
342
+ if checksum_alg :
343
+ checksum_expected = update_desc ["otaChecksum" ].strip ()
344
+ checksum = b64encode (checksum_alg .digest ()).decode ("ascii" )
345
+ if checksum != checksum_expected :
346
+ LOGGER .error (
347
+ "Checksum mismatch for file '%s', expected: '%s', got: '%s'" ,
348
+ file_name ,
349
+ checksum_expected ,
350
+ checksum ,
351
+ )
352
+ raise UpdateError ("Checksum mismatch!" )
353
+
329
354
self ._ota_file_path = file_path
330
355
331
356
async def check_update_state (
0 commit comments