|
| 1 | +"""Handling Matter OTA provider.""" |
| 2 | + |
| 3 | +import asyncio |
| 4 | +from dataclasses import asdict, dataclass |
| 5 | +import json |
| 6 | +import logging |
| 7 | +from pathlib import Path |
| 8 | +from typing import Final |
| 9 | +from urllib.parse import unquote, urlparse |
| 10 | + |
| 11 | +from aiohttp import ClientError, ClientSession |
| 12 | + |
| 13 | +from matter_server.common.helpers.util import dataclass_from_dict |
| 14 | + |
| 15 | +LOGGER = logging.getLogger(__name__) |
| 16 | + |
| 17 | +DEFAULT_UPDATES_PATH: Final[Path] = Path("updates") |
| 18 | + |
| 19 | + |
| 20 | +@dataclass |
| 21 | +class DeviceSoftwareVersionModel: # pylint: disable=C0103 |
| 22 | + """Device Software Version Model for OTA Provider JSON descriptor file.""" |
| 23 | + |
| 24 | + vendorId: int |
| 25 | + productId: int |
| 26 | + softwareVersion: int |
| 27 | + softwareVersionString: str |
| 28 | + cDVersionNumber: int |
| 29 | + softwareVersionValid: bool |
| 30 | + minApplicableSoftwareVersion: int |
| 31 | + maxApplicableSoftwareVersion: int |
| 32 | + otaURL: str |
| 33 | + |
| 34 | + |
| 35 | +@dataclass |
| 36 | +class UpdateFile: # pylint: disable=C0103 |
| 37 | + """Update File for OTA Provider JSON descriptor file.""" |
| 38 | + |
| 39 | + deviceSoftwareVersionModel: list[DeviceSoftwareVersionModel] |
| 40 | + |
| 41 | + |
| 42 | +class ExternalOtaProvider: |
| 43 | + """Class handling Matter OTA Provider. |
| 44 | +
|
| 45 | + The OTA Provider class implements a Matter OTA (over-the-air) update provider |
| 46 | + for devices. |
| 47 | + """ |
| 48 | + |
| 49 | + def __init__(self) -> None: |
| 50 | + """Initialize the OTA provider.""" |
| 51 | + |
| 52 | + def start(self) -> None: |
| 53 | + """Start the OTA Provider.""" |
| 54 | + |
| 55 | + async def add_update(self, update_desc: dict, ota_file: Path) -> None: |
| 56 | + """Add update to the OTA provider.""" |
| 57 | + |
| 58 | + update_json_path = DEFAULT_UPDATES_PATH / "updates.json" |
| 59 | + |
| 60 | + def _read_update_json(update_json_path: Path) -> None | UpdateFile: |
| 61 | + if not update_json_path.exists(): |
| 62 | + return None |
| 63 | + |
| 64 | + with open(update_json_path, "r") as json_file: |
| 65 | + data = json.load(json_file) |
| 66 | + return dataclass_from_dict(UpdateFile, data) |
| 67 | + |
| 68 | + loop = asyncio.get_running_loop() |
| 69 | + update_file = await loop.run_in_executor( |
| 70 | + None, _read_update_json, update_json_path |
| 71 | + ) |
| 72 | + |
| 73 | + if not update_file: |
| 74 | + update_file = UpdateFile(deviceSoftwareVersionModel=[]) |
| 75 | + |
| 76 | + # Convert to OTA Requestor descriptor file |
| 77 | + update_file.deviceSoftwareVersionModel.append( |
| 78 | + DeviceSoftwareVersionModel( |
| 79 | + vendorId=update_desc["vid"], |
| 80 | + productId=update_desc["pid"], |
| 81 | + softwareVersion=update_desc["softwareVersion"], |
| 82 | + softwareVersionString=update_desc["softwareVersionString"], |
| 83 | + cDVersionNumber=update_desc["cdVersionNumber"], |
| 84 | + softwareVersionValid=update_desc["softwareVersionValid"], |
| 85 | + minApplicableSoftwareVersion=update_desc[ |
| 86 | + "minApplicableSoftwareVersion" |
| 87 | + ], |
| 88 | + maxApplicableSoftwareVersion=update_desc[ |
| 89 | + "maxApplicableSoftwareVersion" |
| 90 | + ], |
| 91 | + otaURL=str(ota_file), |
| 92 | + ) |
| 93 | + ) |
| 94 | + |
| 95 | + def _write_update_json(update_json_path: Path, update_file: UpdateFile) -> None: |
| 96 | + update_file_dict = asdict(update_file) |
| 97 | + with open(update_json_path, "w") as json_file: |
| 98 | + json.dump(update_file_dict, json_file, indent=4) |
| 99 | + |
| 100 | + await loop.run_in_executor( |
| 101 | + None, |
| 102 | + _write_update_json, |
| 103 | + update_json_path, |
| 104 | + update_file, |
| 105 | + ) |
| 106 | + |
| 107 | + async def download_update(self, update_desc: dict) -> None: |
| 108 | + """Download update file from OTA Path and add it to the OTA provider.""" |
| 109 | + |
| 110 | + url = update_desc["otaUrl"] |
| 111 | + parsed_url = urlparse(url) |
| 112 | + file_name = unquote(Path(parsed_url.path).name) |
| 113 | + |
| 114 | + loop = asyncio.get_running_loop() |
| 115 | + await loop.run_in_executor(None, DEFAULT_UPDATES_PATH.mkdir) |
| 116 | + |
| 117 | + file_path = DEFAULT_UPDATES_PATH / file_name |
| 118 | + |
| 119 | + try: |
| 120 | + async with ClientSession(raise_for_status=True) as session: |
| 121 | + # fetch the paa certificates list |
| 122 | + logging.debug("Download update from f{url}.") |
| 123 | + async with session.get(url) as response: |
| 124 | + with file_path.open("wb") as f: |
| 125 | + while True: |
| 126 | + chunk = await response.content.read(1024) |
| 127 | + if not chunk: |
| 128 | + break |
| 129 | + f.write(chunk) |
| 130 | + LOGGER.info( |
| 131 | + "File '%s' downloaded to '%s'", file_name, DEFAULT_UPDATES_PATH |
| 132 | + ) |
| 133 | + |
| 134 | + except (ClientError, TimeoutError) as err: |
| 135 | + LOGGER.error( |
| 136 | + "Fetching software version failed: error %s", err, exc_info=err |
| 137 | + ) |
| 138 | + |
| 139 | + await self.add_update(update_desc, file_path) |
0 commit comments