6
6
import json
7
7
import logging
8
8
from pathlib import Path
9
+ import secrets
9
10
from typing import TYPE_CHECKING , Final
10
11
from urllib .parse import unquote , urlparse
11
12
@@ -37,9 +38,12 @@ class DeviceSoftwareVersionModel: # pylint: disable=C0103
37
38
38
39
39
40
@dataclass
40
- class UpdateFile : # pylint: disable=C0103
41
+ class OtaProviderImageList : # pylint: disable=C0103
41
42
"""Update File for OTA Provider JSON descriptor file."""
42
43
44
+ otaProviderDiscriminator : int
45
+ otaProviderPasscode : int
46
+ otaProviderNodeId : int | None
43
47
deviceSoftwareVersionModel : list [DeviceSoftwareVersionModel ]
44
48
45
49
@@ -50,23 +54,103 @@ class ExternalOtaProvider:
50
54
for devices.
51
55
"""
52
56
53
- def __init__ (self ) -> None :
57
+ def __init__ (self , ota_provider_dir : Path ) -> None :
54
58
"""Initialize the OTA provider."""
59
+ self ._ota_provider_dir : Path = ota_provider_dir
60
+ self ._ota_provider_image_list_file : Path = ota_provider_dir / "updates.json"
61
+ self ._ota_provider_image_list : OtaProviderImageList | None = None
55
62
self ._ota_provider_proc : Process | None = None
56
63
self ._ota_provider_task : asyncio .Task | None = None
57
64
65
+ async def initialize (self ) -> None :
66
+ """Initialize OTA Provider."""
67
+
68
+ loop = asyncio .get_event_loop ()
69
+
70
+ # Take existence of image list file as indicator if we need to initialize the
71
+ # OTA Provider.
72
+ if not await loop .run_in_executor (
73
+ None , self ._ota_provider_image_list_file .exists
74
+ ):
75
+ await loop .run_in_executor (
76
+ None , functools .partial (DEFAULT_UPDATES_PATH .mkdir , exist_ok = True )
77
+ )
78
+
79
+ # Initialize with random data. Node ID will get written once paired by
80
+ # device controller.
81
+ self ._ota_provider_image_list = OtaProviderImageList (
82
+ otaProviderDiscriminator = secrets .randbelow (2 ** 12 ),
83
+ otaProviderPasscode = secrets .randbelow (2 ** 21 ),
84
+ otaProviderNodeId = None ,
85
+ deviceSoftwareVersionModel = [],
86
+ )
87
+ else :
88
+
89
+ def _read_update_json (
90
+ update_json_path : Path ,
91
+ ) -> None | OtaProviderImageList :
92
+ with open (update_json_path , "r" ) as json_file :
93
+ data = json .load (json_file )
94
+ return dataclass_from_dict (OtaProviderImageList , data )
95
+
96
+ self ._ota_provider_image_list = await loop .run_in_executor (
97
+ None , _read_update_json , self ._ota_provider_image_list_file
98
+ )
99
+
100
+ def _get_ota_provider_image_list (self ) -> OtaProviderImageList :
101
+ if self ._ota_provider_image_list is None :
102
+ raise RuntimeError ("OTA provider image list not initialized." )
103
+ return self ._ota_provider_image_list
104
+
105
+ def get_node_id (self ) -> int | None :
106
+ """Get Node ID of the OTA Provider App."""
107
+
108
+ return self ._get_ota_provider_image_list ().otaProviderNodeId
109
+
110
+ def get_descriminator (self ) -> int :
111
+ """Return OTA Provider App discriminator."""
112
+
113
+ return self ._get_ota_provider_image_list ().otaProviderDiscriminator
114
+
115
+ def get_passcode (self ) -> int :
116
+ """Return OTA Provider App passcode."""
117
+
118
+ return self ._get_ota_provider_image_list ().otaProviderPasscode
119
+
120
+ def set_node_id (self , node_id : int ) -> None :
121
+ """Set Node ID of the OTA Provider App."""
122
+
123
+ self ._get_ota_provider_image_list ().otaProviderNodeId = node_id
124
+
58
125
async def _start_ota_provider (self ) -> None :
59
- # TODO: Randomize discriminator
126
+ def _write_ota_provider_image_list_json (
127
+ ota_provider_image_list_file : Path ,
128
+ ota_provider_image_list : OtaProviderImageList ,
129
+ ) -> None :
130
+ update_file_dict = asdict (ota_provider_image_list )
131
+ with open (ota_provider_image_list_file , "w" ) as json_file :
132
+ json .dump (update_file_dict , json_file , indent = 4 )
133
+
134
+ loop = asyncio .get_running_loop ()
135
+ await loop .run_in_executor (
136
+ None ,
137
+ _write_ota_provider_image_list_json ,
138
+ self ._ota_provider_image_list_file ,
139
+ self ._get_ota_provider_image_list (),
140
+ )
141
+
60
142
ota_provider_cmd = [
61
143
"chip-ota-provider-app" ,
62
144
"--discriminator" ,
63
- "22" ,
145
+ str (self ._get_ota_provider_image_list ().otaProviderDiscriminator ),
146
+ "--passcode" ,
147
+ str (self ._get_ota_provider_image_list ().otaProviderPasscode ),
64
148
"--secured-device-port" ,
65
149
"5565" ,
66
150
"--KVS" ,
67
- "/data/chip_kvs_provider" ,
151
+ str ( self . _ota_provider_dir / "chip_kvs_ota_provider" ) ,
68
152
"--otaImageList" ,
69
- str (DEFAULT_UPDATES_PATH / "updates.json" ),
153
+ str (self . _ota_provider_image_list_file ),
70
154
]
71
155
72
156
LOGGER .info ("Starting OTA Provider" )
@@ -80,40 +164,41 @@ def start(self) -> None:
80
164
loop = asyncio .get_event_loop ()
81
165
self ._ota_provider_task = loop .create_task (self ._start_ota_provider ())
82
166
167
+ async def reset (self ) -> None :
168
+ """Reset the OTA Provider App state."""
169
+
170
+ def _remove_update_data (ota_provider_dir : Path ) -> None :
171
+ for path in ota_provider_dir .iterdir ():
172
+ if not path .is_dir ():
173
+ path .unlink ()
174
+
175
+ loop = asyncio .get_event_loop ()
176
+ await loop .run_in_executor (None , _remove_update_data , self ._ota_provider_dir )
177
+
178
+ await self .initialize ()
179
+
83
180
async def stop (self ) -> None :
84
181
"""Stop the OTA Provider."""
85
182
if self ._ota_provider_proc :
86
183
LOGGER .info ("Terminating OTA Provider" )
87
- self ._ota_provider_proc .terminate ()
184
+ loop = asyncio .get_event_loop ()
185
+ try :
186
+ await loop .run_in_executor (None , self ._ota_provider_proc .terminate )
187
+ except ProcessLookupError as ex :
188
+ LOGGER .warning ("Stopping OTA Provider failed with error:" , exc_info = ex )
88
189
if self ._ota_provider_task :
89
190
await self ._ota_provider_task
90
191
91
192
async def add_update (self , update_desc : dict , ota_file : Path ) -> None :
92
193
"""Add update to the OTA provider."""
93
194
94
- update_json_path = DEFAULT_UPDATES_PATH / "updates.json"
95
-
96
- def _read_update_json (update_json_path : Path ) -> None | UpdateFile :
97
- if not update_json_path .exists ():
98
- return None
99
-
100
- with open (update_json_path , "r" ) as json_file :
101
- data = json .load (json_file )
102
- return dataclass_from_dict (UpdateFile , data )
103
-
104
- loop = asyncio .get_running_loop ()
105
- update_file = await loop .run_in_executor (
106
- None , _read_update_json , update_json_path
107
- )
108
-
109
- if not update_file :
110
- update_file = UpdateFile (deviceSoftwareVersionModel = [])
111
-
112
195
local_ota_url = str (ota_file )
113
- for i , device_software in enumerate (update_file .deviceSoftwareVersionModel ):
196
+ for i , device_software in enumerate (
197
+ self ._get_ota_provider_image_list ().deviceSoftwareVersionModel
198
+ ):
114
199
if device_software .otaURL == local_ota_url :
115
200
LOGGER .debug ("Device software entry exists already, replacing!" )
116
- del update_file .deviceSoftwareVersionModel [i ]
201
+ del self . _get_ota_provider_image_list () .deviceSoftwareVersionModel [i ]
117
202
118
203
# Convert to OTA Requestor descriptor file
119
204
new_device_software = DeviceSoftwareVersionModel (
@@ -127,18 +212,8 @@ def _read_update_json(update_json_path: Path) -> None | UpdateFile:
127
212
maxApplicableSoftwareVersion = update_desc ["maxApplicableSoftwareVersion" ],
128
213
otaURL = local_ota_url ,
129
214
)
130
- update_file .deviceSoftwareVersionModel .append (new_device_software )
131
-
132
- def _write_update_json (update_json_path : Path , update_file : UpdateFile ) -> None :
133
- update_file_dict = asdict (update_file )
134
- with open (update_json_path , "w" ) as json_file :
135
- json .dump (update_file_dict , json_file , indent = 4 )
136
-
137
- await loop .run_in_executor (
138
- None ,
139
- _write_update_json ,
140
- update_json_path ,
141
- update_file ,
215
+ self ._get_ota_provider_image_list ().deviceSoftwareVersionModel .append (
216
+ new_device_software
142
217
)
143
218
144
219
async def download_update (self , update_desc : dict ) -> None :
0 commit comments