Skip to content

Commit 4e75b77

Browse files
committed
Implement update using OTA Provider app
Use the OTA Provider example app to implement a OTA provider. The example app supports a JSON update descriptor file to manage update metadata. Tested with the OTA Requestor app.
1 parent a90afb7 commit 4e75b77

File tree

3 files changed

+111
-30
lines changed

3 files changed

+111
-30
lines changed

matter_server/server/device_controller.py

+42-10
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,9 @@ async def stop(self) -> None:
220220
for sub in self._subscriptions.values():
221221
await self._call_sdk(sub.Shutdown)
222222
self._subscriptions = {}
223+
# shutdown the OTA Provider
224+
if self._ota_provider:
225+
await self._ota_provider.stop()
223226
# shutdown the sdk device controller
224227
await self._call_sdk(self.chip_controller.Shutdown)
225228
LOGGER.debug("Stopped.")
@@ -921,6 +924,13 @@ async def update_node(self, node_id: int) -> dict | None:
921924
node_logger = LOGGER.getChild(f"node_{node_id}")
922925
node = self._nodes[node_id]
923926

927+
if self.chip_controller is None:
928+
raise RuntimeError("Device Controller not initialized.")
929+
930+
if not self._ota_provider:
931+
LOGGER.warning("No OTA provider found, updates not possible.")
932+
return None
933+
924934
node_logger.debug("Check for updates.")
925935
vid = cast(int, node.attributes.get(BASIC_INFORMATION_VENDOR_ID))
926936
pid = cast(int, node.attributes.get(BASIC_INFORMATION_PRODUCT_ID))
@@ -932,17 +942,39 @@ async def update_node(self, node_id: int) -> dict | None:
932942
)
933943

934944
update = await check_updates(node_id, vid, pid, software_version)
935-
if update and "otaUrl" in update and len(update["otaUrl"]) > 0:
936-
node_logger.info(
937-
"New software update found: %s (current %s). Preparing updates...",
938-
update["softwareVersionString"],
939-
software_version_string,
940-
)
945+
if not update:
946+
node_logger.info("No new update found.")
947+
return None
948+
949+
if "otaUrl" not in update:
950+
node_logger.warning("Update found, but no OTA URL provided.")
951+
return None
941952

942-
# Add to OTA provider
943-
if not self._ota_provider:
944-
return None
945-
await self._ota_provider.download_update(update)
953+
node_logger.info(
954+
"New software update found: %s (current %s). Preparing updates...",
955+
update["softwareVersionString"],
956+
software_version_string,
957+
)
958+
959+
# Add to OTA provider
960+
await self._ota_provider.download_update(update)
961+
962+
self._ota_provider.start()
963+
964+
# Wait for OTA provider to be ready
965+
# TODO: Detect when OTA provider is ready
966+
await asyncio.sleep(2)
967+
968+
await self.chip_controller.SendCommand(
969+
nodeid=node_id,
970+
endpoint=0,
971+
payload=Clusters.OtaSoftwareUpdateRequestor.Commands.AnnounceOTAProvider(
972+
providerNodeID=32,
973+
vendorID=0, # TODO: Use Server Vendor ID
974+
announcementReason=Clusters.OtaSoftwareUpdateRequestor.Enums.AnnouncementReasonEnum.kUpdateAvailable,
975+
endpoint=0,
976+
),
977+
)
946978

947979
return update
948980

matter_server/server/ota/dcl.py

+4
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,13 @@ async def check_updates(
4141

4242
software_versions: list[int] = versions["modelVersions"]["softwareVersions"]
4343
latest_software_version = max(software_versions)
44+
45+
# Check if the software is indeed newer
4446
if latest_software_version <= current_software_version:
4547
return None
4648

49+
# TODO: Check minApplicableSoftwareVersion/maxApplicableSoftwareVersion
50+
4751
version: dict = await get_software_version(
4852
node_id, vid, pid, latest_software_version
4953
)

matter_server/server/ota/provider.py

+65-20
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,20 @@
22

33
import asyncio
44
from dataclasses import asdict, dataclass
5+
import functools
56
import json
67
import logging
78
from pathlib import Path
8-
from typing import Final
9+
from typing import TYPE_CHECKING, Final
910
from urllib.parse import unquote, urlparse
1011

1112
from aiohttp import ClientError, ClientSession
1213

1314
from matter_server.common.helpers.util import dataclass_from_dict
1415

16+
if TYPE_CHECKING:
17+
from asyncio.subprocess import Process
18+
1519
LOGGER = logging.getLogger(__name__)
1620

1721
DEFAULT_UPDATES_PATH: Final[Path] = Path("updates")
@@ -48,10 +52,42 @@ class ExternalOtaProvider:
4852

4953
def __init__(self) -> None:
5054
"""Initialize the OTA provider."""
55+
self._ota_provider_proc: Process | None = None
56+
self._ota_provider_task: asyncio.Task | None = None
57+
58+
async def _start_ota_provider(self) -> None:
59+
# TODO: Randomize discriminator
60+
ota_provider_cmd = [
61+
"chip-ota-provider-app",
62+
"--discriminator",
63+
"22",
64+
"--secured-device-port",
65+
"5565",
66+
"--KVS",
67+
"/data/chip_kvs_provider",
68+
"--otaImageList",
69+
str(DEFAULT_UPDATES_PATH / "updates.json"),
70+
]
71+
72+
LOGGER.info("Starting OTA Provider")
73+
self._ota_provider_proc = await asyncio.create_subprocess_exec(
74+
*ota_provider_cmd
75+
)
5176

5277
def start(self) -> None:
5378
"""Start the OTA Provider."""
5479

80+
loop = asyncio.get_event_loop()
81+
self._ota_provider_task = loop.create_task(self._start_ota_provider())
82+
83+
async def stop(self) -> None:
84+
"""Stop the OTA Provider."""
85+
if self._ota_provider_proc:
86+
LOGGER.info("Terminating OTA Provider")
87+
self._ota_provider_proc.terminate()
88+
if self._ota_provider_task:
89+
await self._ota_provider_task
90+
5591
async def add_update(self, update_desc: dict, ota_file: Path) -> None:
5692
"""Add update to the OTA provider."""
5793

@@ -73,24 +109,25 @@ def _read_update_json(update_json_path: Path) -> None | UpdateFile:
73109
if not update_file:
74110
update_file = UpdateFile(deviceSoftwareVersionModel=[])
75111

112+
local_ota_url = str(ota_file)
113+
for i, device_software in enumerate(update_file.deviceSoftwareVersionModel):
114+
if device_software.otaURL == local_ota_url:
115+
LOGGER.debug("Device software entry exists already, replacing!")
116+
del update_file.deviceSoftwareVersionModel[i]
117+
76118
# Convert to OTA Requestor descriptor file
77-
update_file.deviceSoftwareVersionModel.append(
78-
DeviceSoftwareVersionModel(
79-
vendorId=update_desc["vid"],
80-
productId=update_desc["pid"],
81-
softwareVersion=update_desc["softwareVersion"],
82-
softwareVersionString=update_desc["softwareVersionString"],
83-
cDVersionNumber=update_desc["cdVersionNumber"],
84-
softwareVersionValid=update_desc["softwareVersionValid"],
85-
minApplicableSoftwareVersion=update_desc[
86-
"minApplicableSoftwareVersion"
87-
],
88-
maxApplicableSoftwareVersion=update_desc[
89-
"maxApplicableSoftwareVersion"
90-
],
91-
otaURL=str(ota_file),
92-
)
119+
new_device_software = DeviceSoftwareVersionModel(
120+
vendorId=update_desc["vid"],
121+
productId=update_desc["pid"],
122+
softwareVersion=update_desc["softwareVersion"],
123+
softwareVersionString=update_desc["softwareVersionString"],
124+
cDVersionNumber=update_desc["cdVersionNumber"],
125+
softwareVersionValid=update_desc["softwareVersionValid"],
126+
minApplicableSoftwareVersion=update_desc["minApplicableSoftwareVersion"],
127+
maxApplicableSoftwareVersion=update_desc["maxApplicableSoftwareVersion"],
128+
otaURL=local_ota_url,
93129
)
130+
update_file.deviceSoftwareVersionModel.append(new_device_software)
94131

95132
def _write_update_json(update_json_path: Path, update_file: UpdateFile) -> None:
96133
update_file_dict = asdict(update_file)
@@ -112,9 +149,14 @@ async def download_update(self, update_desc: dict) -> None:
112149
file_name = unquote(Path(parsed_url.path).name)
113150

114151
loop = asyncio.get_running_loop()
115-
await loop.run_in_executor(None, DEFAULT_UPDATES_PATH.mkdir)
152+
await loop.run_in_executor(
153+
None, functools.partial(DEFAULT_UPDATES_PATH.mkdir, exists_ok=True)
154+
)
116155

117156
file_path = DEFAULT_UPDATES_PATH / file_name
157+
if await loop.run_in_executor(None, file_path.exists):
158+
LOGGER.info("File '%s' exists already, skipping download.", file_name)
159+
return
118160

119161
try:
120162
async with ClientSession(raise_for_status=True) as session:
@@ -123,10 +165,13 @@ async def download_update(self, update_desc: dict) -> None:
123165
async with session.get(url) as response:
124166
with file_path.open("wb") as f:
125167
while True:
126-
chunk = await response.content.read(1024)
168+
chunk = await response.content.read(4048)
127169
if not chunk:
128170
break
129-
f.write(chunk)
171+
await loop.run_in_executor(None, f.write, chunk)
172+
173+
# TODO: Check against otaChecksum
174+
130175
LOGGER.info(
131176
"File '%s' downloaded to '%s'", file_name, DEFAULT_UPDATES_PATH
132177
)

0 commit comments

Comments
 (0)