Skip to content

Commit fcf6229

Browse files
authored
Support local update file for OTA (#884)
1 parent 0af331d commit fcf6229

File tree

2 files changed

+61
-34
lines changed

2 files changed

+61
-34
lines changed

matter_server/server/device_controller.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -972,13 +972,15 @@ async def update_node(self, node_id: int, software_version: int | str) -> None:
972972

973973
# Add update to the OTA provider
974974
ota_provider = ExternalOtaProvider(
975-
self.server.vendor_id, self._ota_provider_dir / f"{node_id}"
975+
self.server.vendor_id,
976+
self._ota_provider_dir,
977+
self._ota_provider_dir / f"{node_id}",
976978
)
977979

978980
await ota_provider.initialize()
979981

980982
node_logger.info("Downloading update from '%s'", update["otaUrl"])
981-
await ota_provider.download_update(update)
983+
await ota_provider.fetch_update(update)
982984

983985
self._attribute_update_callbacks.setdefault(node_id, []).append(
984986
ota_provider.check_update_state

matter_server/server/ota/provider.py

+57-32
Original file line numberDiff line numberDiff line change
@@ -65,9 +65,12 @@ class ExternalOtaProvider:
6565

6666
ENDPOINT_ID: Final[int] = 0
6767

68-
def __init__(self, vendor_id: int, ota_provider_dir: Path) -> None:
68+
def __init__(
69+
self, vendor_id: int, ota_provider_base_dir: Path, ota_provider_dir: Path
70+
) -> None:
6971
"""Initialize the OTA provider."""
7072
self._vendor_id: int = vendor_id
73+
self._ota_provider_base_dir: Path = ota_provider_base_dir
7174
self._ota_provider_dir: Path = ota_provider_dir
7275
self._ota_file_path: Path | None = None
7376
self._ota_provider_proc: Process | None = None
@@ -261,10 +264,11 @@ async def stop(self) -> None:
261264
self._ota_provider_proc = None
262265
self._ota_provider_task = None
263266

264-
async def download_update(self, update_desc: dict) -> None:
265-
"""Download update file from OTA Path and add it to the OTA provider."""
267+
async def _download_update(
268+
self, url: str, checksum_alg: hashlib._Hash | None
269+
) -> Path:
270+
"""Download update file from OTA URL."""
266271

267-
url = update_desc["otaUrl"]
268272
parsed_url = urlparse(url)
269273
file_name = unquote(Path(parsed_url.path).name)
270274

@@ -273,20 +277,6 @@ async def download_update(self, update_desc: dict) -> None:
273277
file_path = self._ota_provider_dir / file_name
274278

275279
try:
276-
checksum_alg = None
277-
if (
278-
"otaChecksum" in update_desc
279-
and "otaChecksumType" in update_desc
280-
and update_desc["otaChecksumType"] in CHECHKSUM_TYPES
281-
):
282-
checksum_alg = hashlib.new(
283-
CHECHKSUM_TYPES[update_desc["otaChecksumType"]]
284-
)
285-
else:
286-
LOGGER.warning(
287-
"No OTA checksum type or not supported, OTA will not be checked."
288-
)
289-
290280
async with ClientSession(raise_for_status=True) as session:
291281
# fetch the paa certificates list
292282
LOGGER.debug("Download update from '%s'.", url)
@@ -300,20 +290,6 @@ async def download_update(self, update_desc: dict) -> None:
300290
if checksum_alg:
301291
checksum_alg.update(chunk)
302292

303-
# Download finished, check checksum if necessary
304-
if checksum_alg:
305-
checksum = b64encode(checksum_alg.digest()).decode("ascii")
306-
checksum_expected = update_desc["otaChecksum"].strip()
307-
if checksum != checksum_expected:
308-
LOGGER.error(
309-
"Checksum mismatch for file '%s', expected: '%s', got: '%s'",
310-
file_name,
311-
checksum_expected,
312-
checksum,
313-
)
314-
await loop.run_in_executor(None, file_path.unlink)
315-
raise UpdateError("Checksum mismatch!")
316-
317293
LOGGER.info(
318294
"Update file '%s' downloaded to '%s'",
319295
file_name,
@@ -326,6 +302,55 @@ async def download_update(self, update_desc: dict) -> None:
326302
)
327303
raise UpdateError("Fetching software version failed") from err
328304

305+
return file_path
306+
307+
async def fetch_update(self, update_desc: dict) -> None:
308+
"""Fetch update file from OTA URL."""
309+
url = update_desc["otaUrl"]
310+
parsed_url = urlparse(url)
311+
file_name = unquote(Path(parsed_url.path).name)
312+
313+
loop = asyncio.get_running_loop()
314+
315+
checksum_alg = None
316+
if (
317+
"otaChecksum" in update_desc
318+
and "otaChecksumType" in update_desc
319+
and update_desc["otaChecksumType"] in CHECHKSUM_TYPES
320+
):
321+
checksum_alg = hashlib.new(CHECHKSUM_TYPES[update_desc["otaChecksumType"]])
322+
else:
323+
LOGGER.warning(
324+
"No OTA checksum type or not supported, OTA will not be checked."
325+
)
326+
327+
if parsed_url.scheme in ["http", "https"]:
328+
file_path = await self._download_update(url, checksum_alg)
329+
elif parsed_url.scheme in ["file"]:
330+
file_path = self._ota_provider_base_dir / Path(parsed_url.path[1:])
331+
if not file_path.exists():
332+
logging.warning("Local update file not found: %s", file_path)
333+
raise UpdateError("Local update file not found")
334+
if checksum_alg:
335+
checksum_alg.update(
336+
await loop.run_in_executor(None, file_path.read_bytes)
337+
)
338+
else:
339+
raise UpdateError("Unsupported OTA URL scheme")
340+
341+
# Download finished, check checksum if necessary
342+
if checksum_alg:
343+
checksum_expected = update_desc["otaChecksum"].strip()
344+
checksum = b64encode(checksum_alg.digest()).decode("ascii")
345+
if checksum != checksum_expected:
346+
LOGGER.error(
347+
"Checksum mismatch for file '%s', expected: '%s', got: '%s'",
348+
file_name,
349+
checksum_expected,
350+
checksum,
351+
)
352+
raise UpdateError("Checksum mismatch!")
353+
329354
self._ota_file_path = file_path
330355

331356
async def check_update_state(

0 commit comments

Comments
 (0)