2
2
3
3
import asyncio
4
4
from dataclasses import asdict , dataclass
5
+ import functools
5
6
import json
6
7
import logging
7
8
from pathlib import Path
8
- from typing import Final
9
+ from typing import TYPE_CHECKING , Final
9
10
from urllib .parse import unquote , urlparse
10
11
11
12
from aiohttp import ClientError , ClientSession
12
13
13
14
from matter_server .common .helpers .util import dataclass_from_dict
14
15
16
+ if TYPE_CHECKING :
17
+ from asyncio .subprocess import Process
18
+
15
19
LOGGER = logging .getLogger (__name__ )
16
20
17
21
DEFAULT_UPDATES_PATH : Final [Path ] = Path ("updates" )
@@ -48,10 +52,42 @@ class ExternalOtaProvider:
48
52
49
53
def __init__ (self ) -> None :
50
54
"""Initialize the OTA provider."""
55
+ self ._ota_provider_proc : Process | None = None
56
+ self ._ota_provider_task : asyncio .Task | None = None
57
+
58
+ async def _start_ota_provider (self ) -> None :
59
+ # TODO: Randomize discriminator
60
+ ota_provider_cmd = [
61
+ "chip-ota-provider-app" ,
62
+ "--discriminator" ,
63
+ "22" ,
64
+ "--secured-device-port" ,
65
+ "5565" ,
66
+ "--KVS" ,
67
+ "/data/chip_kvs_provider" ,
68
+ "--otaImageList" ,
69
+ str (DEFAULT_UPDATES_PATH / "updates.json" ),
70
+ ]
71
+
72
+ LOGGER .info ("Starting OTA Provider" )
73
+ self ._ota_provider_proc = await asyncio .create_subprocess_exec (
74
+ * ota_provider_cmd
75
+ )
51
76
52
77
def start (self ) -> None :
53
78
"""Start the OTA Provider."""
54
79
80
+ loop = asyncio .get_event_loop ()
81
+ self ._ota_provider_task = loop .create_task (self ._start_ota_provider ())
82
+
83
+ async def stop (self ) -> None :
84
+ """Stop the OTA Provider."""
85
+ if self ._ota_provider_proc :
86
+ LOGGER .info ("Terminating OTA Provider" )
87
+ self ._ota_provider_proc .terminate ()
88
+ if self ._ota_provider_task :
89
+ await self ._ota_provider_task
90
+
55
91
async def add_update (self , update_desc : dict , ota_file : Path ) -> None :
56
92
"""Add update to the OTA provider."""
57
93
@@ -73,24 +109,25 @@ def _read_update_json(update_json_path: Path) -> None | UpdateFile:
73
109
if not update_file :
74
110
update_file = UpdateFile (deviceSoftwareVersionModel = [])
75
111
112
+ local_ota_url = str (ota_file )
113
+ for i , device_software in enumerate (update_file .deviceSoftwareVersionModel ):
114
+ if device_software .otaURL == local_ota_url :
115
+ LOGGER .debug ("Device software entry exists already, replacing!" )
116
+ del update_file .deviceSoftwareVersionModel [i ]
117
+
76
118
# Convert to OTA Requestor descriptor file
77
- update_file .deviceSoftwareVersionModel .append (
78
- DeviceSoftwareVersionModel (
79
- vendorId = update_desc ["vid" ],
80
- productId = update_desc ["pid" ],
81
- softwareVersion = update_desc ["softwareVersion" ],
82
- softwareVersionString = update_desc ["softwareVersionString" ],
83
- cDVersionNumber = update_desc ["cdVersionNumber" ],
84
- softwareVersionValid = update_desc ["softwareVersionValid" ],
85
- minApplicableSoftwareVersion = update_desc [
86
- "minApplicableSoftwareVersion"
87
- ],
88
- maxApplicableSoftwareVersion = update_desc [
89
- "maxApplicableSoftwareVersion"
90
- ],
91
- otaURL = str (ota_file ),
92
- )
119
+ new_device_software = DeviceSoftwareVersionModel (
120
+ vendorId = update_desc ["vid" ],
121
+ productId = update_desc ["pid" ],
122
+ softwareVersion = update_desc ["softwareVersion" ],
123
+ softwareVersionString = update_desc ["softwareVersionString" ],
124
+ cDVersionNumber = update_desc ["cdVersionNumber" ],
125
+ softwareVersionValid = update_desc ["softwareVersionValid" ],
126
+ minApplicableSoftwareVersion = update_desc ["minApplicableSoftwareVersion" ],
127
+ maxApplicableSoftwareVersion = update_desc ["maxApplicableSoftwareVersion" ],
128
+ otaURL = local_ota_url ,
93
129
)
130
+ update_file .deviceSoftwareVersionModel .append (new_device_software )
94
131
95
132
def _write_update_json (update_json_path : Path , update_file : UpdateFile ) -> None :
96
133
update_file_dict = asdict (update_file )
@@ -112,9 +149,14 @@ async def download_update(self, update_desc: dict) -> None:
112
149
file_name = unquote (Path (parsed_url .path ).name )
113
150
114
151
loop = asyncio .get_running_loop ()
115
- await loop .run_in_executor (None , DEFAULT_UPDATES_PATH .mkdir )
152
+ await loop .run_in_executor (
153
+ None , functools .partial (DEFAULT_UPDATES_PATH .mkdir , exists_ok = True )
154
+ )
116
155
117
156
file_path = DEFAULT_UPDATES_PATH / file_name
157
+ if await loop .run_in_executor (None , file_path .exists ):
158
+ LOGGER .info ("File '%s' exists already, skipping download." , file_name )
159
+ return
118
160
119
161
try :
120
162
async with ClientSession (raise_for_status = True ) as session :
@@ -123,10 +165,13 @@ async def download_update(self, update_desc: dict) -> None:
123
165
async with session .get (url ) as response :
124
166
with file_path .open ("wb" ) as f :
125
167
while True :
126
- chunk = await response .content .read (1024 )
168
+ chunk = await response .content .read (4048 )
127
169
if not chunk :
128
170
break
129
- f .write (chunk )
171
+ await loop .run_in_executor (None , f .write , chunk )
172
+
173
+ # TODO: Check against otaChecksum
174
+
130
175
LOGGER .info (
131
176
"File '%s' downloaded to '%s'" , file_name , DEFAULT_UPDATES_PATH
132
177
)
0 commit comments