Skip to content

Commit 2bab7e3

Browse files
committed
Implement update using OTA Provider app
Use the OTA Provider example app to implement a OTA provider. The example app supports a JSON update descriptor file to manage update metadata. Tested with the OTA Requestor app.
1 parent 17f9f54 commit 2bab7e3

File tree

2 files changed

+107
-28
lines changed

2 files changed

+107
-28
lines changed

matter_server/server/device_controller.py

+42-8
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,9 @@ async def stop(self) -> None:
221221

222222
# shutdown the sdk device controller
223223
await self._chip_device_controller.shutdown()
224+
# shutdown the OTA Provider
225+
if self._ota_provider:
226+
await self._ota_provider.stop()
224227
LOGGER.debug("Stopped.")
225228

226229
@property
@@ -903,6 +906,13 @@ async def update_node(self, node_id: int) -> dict | None:
903906
node_logger = LOGGER.getChild(f"node_{node_id}")
904907
node = self._nodes[node_id]
905908

909+
if self.chip_controller is None:
910+
raise RuntimeError("Device Controller not initialized.")
911+
912+
if not self._ota_provider:
913+
LOGGER.warning("No OTA provider found, updates not possible.")
914+
return None
915+
906916
node_logger.debug("Check for updates.")
907917
vid = cast(int, node.attributes.get(BASIC_INFORMATION_VENDOR_ID_ATTRIBUTE_PATH))
908918
pid = cast(
@@ -916,15 +926,39 @@ async def update_node(self, node_id: int) -> dict | None:
916926
)
917927

918928
update = await check_updates(node_id, vid, pid, software_version)
919-
if update and "otaUrl" in update and len(update["otaUrl"]) > 0:
920-
node_logger.info(
921-
"New software update found: %s (current %s). Preparing updates...",
922-
update["softwareVersionString"],
923-
software_version_string,
924-
)
929+
if not update:
930+
node_logger.info("No new update found.")
931+
return None
932+
933+
if "otaUrl" not in update:
934+
node_logger.warning("Update found, but no OTA URL provided.")
935+
return None
925936

926-
# Add to OTA provider
927-
await self._ota_provider.download_update(update)
937+
node_logger.info(
938+
"New software update found: %s (current %s). Preparing updates...",
939+
update["softwareVersionString"],
940+
software_version_string,
941+
)
942+
943+
# Add to OTA provider
944+
await self._ota_provider.download_update(update)
945+
946+
self._ota_provider.start()
947+
948+
# Wait for OTA provider to be ready
949+
# TODO: Detect when OTA provider is ready
950+
await asyncio.sleep(2)
951+
952+
await self.chip_controller.SendCommand(
953+
nodeid=node_id,
954+
endpoint=0,
955+
payload=Clusters.OtaSoftwareUpdateRequestor.Commands.AnnounceOTAProvider(
956+
providerNodeID=32,
957+
vendorID=0, # TODO: Use Server Vendor ID
958+
announcementReason=Clusters.OtaSoftwareUpdateRequestor.Enums.AnnouncementReasonEnum.kUpdateAvailable,
959+
endpoint=0,
960+
),
961+
)
928962

929963
return update
930964

matter_server/server/ota/provider.py

+65-20
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,20 @@
22

33
import asyncio
44
from dataclasses import asdict, dataclass
5+
import functools
56
import json
67
import logging
78
from pathlib import Path
8-
from typing import Final
9+
from typing import TYPE_CHECKING, Final
910
from urllib.parse import unquote, urlparse
1011

1112
from aiohttp import ClientError, ClientSession
1213

1314
from matter_server.common.helpers.util import dataclass_from_dict
1415

16+
if TYPE_CHECKING:
17+
from asyncio.subprocess import Process
18+
1519
LOGGER = logging.getLogger(__name__)
1620

1721
DEFAULT_UPDATES_PATH: Final[Path] = Path("updates")
@@ -48,10 +52,42 @@ class ExternalOtaProvider:
4852

4953
def __init__(self) -> None:
5054
"""Initialize the OTA provider."""
55+
self._ota_provider_proc: Process | None = None
56+
self._ota_provider_task: asyncio.Task | None = None
57+
58+
async def _start_ota_provider(self) -> None:
59+
# TODO: Randomize discriminator
60+
ota_provider_cmd = [
61+
"chip-ota-provider-app",
62+
"--discriminator",
63+
"22",
64+
"--secured-device-port",
65+
"5565",
66+
"--KVS",
67+
"/data/chip_kvs_provider",
68+
"--otaImageList",
69+
str(DEFAULT_UPDATES_PATH / "updates.json"),
70+
]
71+
72+
LOGGER.info("Starting OTA Provider")
73+
self._ota_provider_proc = await asyncio.create_subprocess_exec(
74+
*ota_provider_cmd
75+
)
5176

5277
def start(self) -> None:
5378
"""Start the OTA Provider."""
5479

80+
loop = asyncio.get_event_loop()
81+
self._ota_provider_task = loop.create_task(self._start_ota_provider())
82+
83+
async def stop(self) -> None:
84+
"""Stop the OTA Provider."""
85+
if self._ota_provider_proc:
86+
LOGGER.info("Terminating OTA Provider")
87+
self._ota_provider_proc.terminate()
88+
if self._ota_provider_task:
89+
await self._ota_provider_task
90+
5591
async def add_update(self, update_desc: dict, ota_file: Path) -> None:
5692
"""Add update to the OTA provider."""
5793

@@ -73,24 +109,25 @@ def _read_update_json(update_json_path: Path) -> None | UpdateFile:
73109
if not update_file:
74110
update_file = UpdateFile(deviceSoftwareVersionModel=[])
75111

112+
local_ota_url = str(ota_file)
113+
for i, device_software in enumerate(update_file.deviceSoftwareVersionModel):
114+
if device_software.otaURL == local_ota_url:
115+
LOGGER.debug("Device software entry exists already, replacing!")
116+
del update_file.deviceSoftwareVersionModel[i]
117+
76118
# Convert to OTA Requestor descriptor file
77-
update_file.deviceSoftwareVersionModel.append(
78-
DeviceSoftwareVersionModel(
79-
vendorId=update_desc["vid"],
80-
productId=update_desc["pid"],
81-
softwareVersion=update_desc["softwareVersion"],
82-
softwareVersionString=update_desc["softwareVersionString"],
83-
cDVersionNumber=update_desc["cdVersionNumber"],
84-
softwareVersionValid=update_desc["softwareVersionValid"],
85-
minApplicableSoftwareVersion=update_desc[
86-
"minApplicableSoftwareVersion"
87-
],
88-
maxApplicableSoftwareVersion=update_desc[
89-
"maxApplicableSoftwareVersion"
90-
],
91-
otaURL=str(ota_file),
92-
)
119+
new_device_software = DeviceSoftwareVersionModel(
120+
vendorId=update_desc["vid"],
121+
productId=update_desc["pid"],
122+
softwareVersion=update_desc["softwareVersion"],
123+
softwareVersionString=update_desc["softwareVersionString"],
124+
cDVersionNumber=update_desc["cdVersionNumber"],
125+
softwareVersionValid=update_desc["softwareVersionValid"],
126+
minApplicableSoftwareVersion=update_desc["minApplicableSoftwareVersion"],
127+
maxApplicableSoftwareVersion=update_desc["maxApplicableSoftwareVersion"],
128+
otaURL=local_ota_url,
93129
)
130+
update_file.deviceSoftwareVersionModel.append(new_device_software)
94131

95132
def _write_update_json(update_json_path: Path, update_file: UpdateFile) -> None:
96133
update_file_dict = asdict(update_file)
@@ -112,9 +149,14 @@ async def download_update(self, update_desc: dict) -> None:
112149
file_name = unquote(Path(parsed_url.path).name)
113150

114151
loop = asyncio.get_running_loop()
115-
await loop.run_in_executor(None, DEFAULT_UPDATES_PATH.mkdir)
152+
await loop.run_in_executor(
153+
None, functools.partial(DEFAULT_UPDATES_PATH.mkdir, exists_ok=True)
154+
)
116155

117156
file_path = DEFAULT_UPDATES_PATH / file_name
157+
if await loop.run_in_executor(None, file_path.exists):
158+
LOGGER.info("File '%s' exists already, skipping download.", file_name)
159+
return
118160

119161
try:
120162
async with ClientSession(raise_for_status=True) as session:
@@ -123,10 +165,13 @@ async def download_update(self, update_desc: dict) -> None:
123165
async with session.get(url) as response:
124166
with file_path.open("wb") as f:
125167
while True:
126-
chunk = await response.content.read(1024)
168+
chunk = await response.content.read(4048)
127169
if not chunk:
128170
break
129-
f.write(chunk)
171+
await loop.run_in_executor(None, f.write, chunk)
172+
173+
# TODO: Check against otaChecksum
174+
130175
LOGGER.info(
131176
"File '%s' downloaded to '%s'", file_name, DEFAULT_UPDATES_PATH
132177
)

0 commit comments

Comments
 (0)