Skip to content

Commit 745ff7d

Browse files
committed
Setup OTA Provider App automatically when necessary
Start and commission OTA Provider App when necessary. Use random discriminator and passcode. Store the Node ID of the OTA Provider App once setup for fast re-use.
1 parent 4e75b77 commit 745ff7d

File tree

5 files changed

+219
-45
lines changed

5 files changed

+219
-45
lines changed

matter_server/server/__main__.py

+7
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,12 @@
9898
default=None,
9999
help="Directory where PAA root certificates are stored.",
100100
)
101+
parser.add_argument(
102+
"--ota-provider-dir",
103+
type=str,
104+
default=None,
105+
help="Directory where OTA Provider stores software updates and configuration.",
106+
)
101107

102108
args = parser.parse_args()
103109

@@ -186,6 +192,7 @@ def main() -> None:
186192
args.listen_address,
187193
args.primary_interface,
188194
args.paa_root_cert_dir,
195+
args.ota_provider_dir,
189196
)
190197

191198
async def handle_stop(loop: asyncio.AbstractEventLoop) -> None:

matter_server/server/const.py

+2
Original file line numberDiff line numberDiff line change
@@ -20,3 +20,5 @@
2020
.parent.resolve()
2121
.joinpath("credentials/development/paa-root-certs")
2222
)
23+
24+
DEFAULT_OTA_PROVIDER_DIR: Final[pathlib.Path] = pathlib.Path().cwd().joinpath("updates")

matter_server/server/device_controller.py

+83-4
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,12 @@
1313
import time
1414
from typing import TYPE_CHECKING, Any, TypeVar, cast
1515

16-
from chip.clusters import Attribute, Objects as Clusters
16+
from chip.clusters import Attribute, Objects as Clusters, Types
1717
from chip.clusters.Attribute import ValueDecodeFailure
1818
from chip.clusters.ClusterObjects import ALL_ATTRIBUTES, ALL_CLUSTERS, Cluster
1919
from chip.discovery import DiscoveryType
2020
from chip.exceptions import ChipStackError
21+
from chip.interaction_model import Status
2122
from zeroconf import BadTypeInNameException, IPVersion, ServiceStateChange, Zeroconf
2223
from zeroconf.asyncio import AsyncServiceBrowser, AsyncServiceInfo, AsyncZeroconf
2324

@@ -138,7 +139,7 @@ def __init__(
138139
self._node_lock: dict[int, asyncio.Lock] = {}
139140
self._ota_provider: ExternalOtaProvider | None = None
140141

141-
async def initialize(self, paa_root_cert_dir: Path) -> None:
142+
async def initialize(self, paa_root_cert_dir: Path, ota_provider_dir: Path) -> None:
142143
"""Async initialize of controller."""
143144
# (re)fetch all PAA certificates once at startup
144145
# NOTE: this must be done before initializing the controller
@@ -152,7 +153,8 @@ async def initialize(self, paa_root_cert_dir: Path) -> None:
152153
int, await self._call_sdk(self.chip_controller.GetCompressedFabricId)
153154
)
154155
self.fabric_id_hex = hex(self.compressed_fabric_id)[2:]
155-
self._ota_provider = ExternalOtaProvider()
156+
self._ota_provider = ExternalOtaProvider(ota_provider_dir)
157+
await self._ota_provider.initialize()
156158
LOGGER.debug("CHIP Device Controller Initialized")
157159

158160
async def start(self) -> None:
@@ -959,17 +961,94 @@ async def update_node(self, node_id: int) -> dict | None:
959961
# Add to OTA provider
960962
await self._ota_provider.download_update(update)
961963

964+
ota_provider_node_id = self._ota_provider.get_node_id()
965+
if ota_provider_node_id not in self._nodes:
966+
LOGGER.warning(
967+
"OTA Provider node id %d no longer exists! Resetting...",
968+
ota_provider_node_id,
969+
)
970+
await self._ota_provider.reset()
971+
ota_provider_node_id = None
972+
973+
# Make sure any previous instances get stopped
974+
await self._ota_provider.stop()
962975
self._ota_provider.start()
963976

964977
# Wait for OTA provider to be ready
965978
# TODO: Detect when OTA provider is ready
966979
await asyncio.sleep(2)
967980

981+
if not ota_provider_node_id:
982+
# The OTA Provider has not been commissioned yet, let's do it now.
983+
LOGGER.info("Commissioning the built-in OTA Provider App.")
984+
try:
985+
ota_provider_node = await self.commission_on_network(
986+
self._ota_provider.get_passcode(),
987+
# TODO: Filtering by long discriminator seems broken
988+
# filter_type=FilterType.LONG_DISCRIMINATOR,
989+
# filter=self._ota_provider.get_descriminator(),
990+
)
991+
ota_provider_node_id = ota_provider_node.node_id
992+
except NodeCommissionFailed:
993+
LOGGER.error("Failed to commission OTA Provider App!")
994+
return None
995+
LOGGER.info(
996+
"OTA Provider App commissioned with node id %d.",
997+
ota_provider_node_id,
998+
)
999+
1000+
# Adjust ACL of OTA Requestor such that Node peer-to-peer communication
1001+
# is allowed.
1002+
try:
1003+
read_result = await self.chip_controller.ReadAttribute(
1004+
ota_provider_node_id, [(0, Clusters.AccessControl.Attributes.Acl)]
1005+
)
1006+
acl_list = cast(
1007+
list,
1008+
read_result[0][Clusters.AccessControl][
1009+
Clusters.AccessControl.Attributes.Acl
1010+
],
1011+
)
1012+
1013+
# Add new ACL entry...
1014+
acl_list.append(
1015+
Clusters.AccessControl.Structs.AccessControlEntryStruct(
1016+
fabricIndex=1,
1017+
privilege=3,
1018+
authMode=2,
1019+
subjects=Types.NullValue,
1020+
targets=[
1021+
Clusters.AccessControl.Structs.AccessControlTargetStruct(
1022+
cluster=41, endpoint=0, deviceType=Types.NullValue
1023+
)
1024+
],
1025+
)
1026+
)
1027+
1028+
# And write. This is persistent, so only need to be done after we commissioned
1029+
# the OTA Provider App.
1030+
write_result: Attribute.AttributeWriteResult = (
1031+
await self.chip_controller.WriteAttribute(
1032+
ota_provider_node_id,
1033+
[(0, Clusters.AccessControl.Attributes.Acl(acl_list))],
1034+
)
1035+
)
1036+
if write_result[0].Status != Status.Success:
1037+
logging.error("Failed writing adjusted OTA Provider App ACL.")
1038+
await self.remove_node(ota_provider_node_id)
1039+
return None
1040+
except ChipStackError as ex:
1041+
logging.exception("Failed adjusting OTA Provider App ACL.", exc_info=ex)
1042+
await self.remove_node(ota_provider_node_id)
1043+
else:
1044+
self._ota_provider.set_node_id(ota_provider_node_id)
1045+
1046+
# Notify node about the new update!
9681047
await self.chip_controller.SendCommand(
9691048
nodeid=node_id,
9701049
endpoint=0,
9711050
payload=Clusters.OtaSoftwareUpdateRequestor.Commands.AnnounceOTAProvider(
972-
providerNodeID=32,
1051+
providerNodeID=ota_provider_node_id,
9731052
vendorID=0, # TODO: Use Server Vendor ID
9741053
announcementReason=Clusters.OtaSoftwareUpdateRequestor.Enums.AnnouncementReasonEnum.kUpdateAvailable,
9751054
endpoint=0,

matter_server/server/ota/provider.py

+114-39
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import json
77
import logging
88
from pathlib import Path
9+
import secrets
910
from typing import TYPE_CHECKING, Final
1011
from urllib.parse import unquote, urlparse
1112

@@ -37,9 +38,12 @@ class DeviceSoftwareVersionModel: # pylint: disable=C0103
3738

3839

3940
@dataclass
40-
class UpdateFile: # pylint: disable=C0103
41+
class OtaProviderImageList: # pylint: disable=C0103
4142
"""Update File for OTA Provider JSON descriptor file."""
4243

44+
otaProviderDiscriminator: int
45+
otaProviderPasscode: int
46+
otaProviderNodeId: int | None
4347
deviceSoftwareVersionModel: list[DeviceSoftwareVersionModel]
4448

4549

@@ -50,23 +54,103 @@ class ExternalOtaProvider:
5054
for devices.
5155
"""
5256

53-
def __init__(self) -> None:
57+
def __init__(self, ota_provider_dir: Path) -> None:
5458
"""Initialize the OTA provider."""
59+
self._ota_provider_dir: Path = ota_provider_dir
60+
self._ota_provider_image_list_file: Path = ota_provider_dir / "updates.json"
61+
self._ota_provider_image_list: OtaProviderImageList | None = None
5562
self._ota_provider_proc: Process | None = None
5663
self._ota_provider_task: asyncio.Task | None = None
5764

65+
async def initialize(self) -> None:
66+
"""Initialize OTA Provider."""
67+
68+
loop = asyncio.get_event_loop()
69+
70+
# Take existence of image list file as indicator if we need to initialize the
71+
# OTA Provider.
72+
if not await loop.run_in_executor(
73+
None, self._ota_provider_image_list_file.exists
74+
):
75+
await loop.run_in_executor(
76+
None, functools.partial(DEFAULT_UPDATES_PATH.mkdir, exist_ok=True)
77+
)
78+
79+
# Initialize with random data. Node ID will get written once paired by
80+
# device controller.
81+
self._ota_provider_image_list = OtaProviderImageList(
82+
otaProviderDiscriminator=secrets.randbelow(2**12),
83+
otaProviderPasscode=secrets.randbelow(2**21),
84+
otaProviderNodeId=None,
85+
deviceSoftwareVersionModel=[],
86+
)
87+
else:
88+
89+
def _read_update_json(
90+
update_json_path: Path,
91+
) -> None | OtaProviderImageList:
92+
with open(update_json_path, "r") as json_file:
93+
data = json.load(json_file)
94+
return dataclass_from_dict(OtaProviderImageList, data)
95+
96+
self._ota_provider_image_list = await loop.run_in_executor(
97+
None, _read_update_json, self._ota_provider_image_list_file
98+
)
99+
100+
def _get_ota_provider_image_list(self) -> OtaProviderImageList:
101+
if self._ota_provider_image_list is None:
102+
raise RuntimeError("OTA provider image list not initialized.")
103+
return self._ota_provider_image_list
104+
105+
def get_node_id(self) -> int | None:
106+
"""Get Node ID of the OTA Provider App."""
107+
108+
return self._get_ota_provider_image_list().otaProviderNodeId
109+
110+
def get_descriminator(self) -> int:
111+
"""Return OTA Provider App discriminator."""
112+
113+
return self._get_ota_provider_image_list().otaProviderDiscriminator
114+
115+
def get_passcode(self) -> int:
116+
"""Return OTA Provider App passcode."""
117+
118+
return self._get_ota_provider_image_list().otaProviderPasscode
119+
120+
def set_node_id(self, node_id: int) -> None:
121+
"""Set Node ID of the OTA Provider App."""
122+
123+
self._get_ota_provider_image_list().otaProviderNodeId = node_id
124+
58125
async def _start_ota_provider(self) -> None:
59-
# TODO: Randomize discriminator
126+
def _write_ota_provider_image_list_json(
127+
ota_provider_image_list_file: Path,
128+
ota_provider_image_list: OtaProviderImageList,
129+
) -> None:
130+
update_file_dict = asdict(ota_provider_image_list)
131+
with open(ota_provider_image_list_file, "w") as json_file:
132+
json.dump(update_file_dict, json_file, indent=4)
133+
134+
loop = asyncio.get_running_loop()
135+
await loop.run_in_executor(
136+
None,
137+
_write_ota_provider_image_list_json,
138+
self._ota_provider_image_list_file,
139+
self._get_ota_provider_image_list(),
140+
)
141+
60142
ota_provider_cmd = [
61143
"chip-ota-provider-app",
62144
"--discriminator",
63-
"22",
145+
str(self._get_ota_provider_image_list().otaProviderDiscriminator),
146+
"--passcode",
147+
str(self._get_ota_provider_image_list().otaProviderPasscode),
64148
"--secured-device-port",
65149
"5565",
66150
"--KVS",
67-
"/data/chip_kvs_provider",
151+
str(self._ota_provider_dir / "chip_kvs_ota_provider"),
68152
"--otaImageList",
69-
str(DEFAULT_UPDATES_PATH / "updates.json"),
153+
str(self._ota_provider_image_list_file),
70154
]
71155

72156
LOGGER.info("Starting OTA Provider")
@@ -80,40 +164,41 @@ def start(self) -> None:
80164
loop = asyncio.get_event_loop()
81165
self._ota_provider_task = loop.create_task(self._start_ota_provider())
82166

167+
async def reset(self) -> None:
168+
"""Reset the OTA Provider App state."""
169+
170+
def _remove_update_data(ota_provider_dir: Path) -> None:
171+
for path in ota_provider_dir.iterdir():
172+
if not path.is_dir():
173+
path.unlink()
174+
175+
loop = asyncio.get_event_loop()
176+
await loop.run_in_executor(None, _remove_update_data, self._ota_provider_dir)
177+
178+
await self.initialize()
179+
83180
async def stop(self) -> None:
84181
"""Stop the OTA Provider."""
85182
if self._ota_provider_proc:
86183
LOGGER.info("Terminating OTA Provider")
87-
self._ota_provider_proc.terminate()
184+
loop = asyncio.get_event_loop()
185+
try:
186+
await loop.run_in_executor(None, self._ota_provider_proc.terminate)
187+
except ProcessLookupError as ex:
188+
LOGGER.warning("Stopping OTA Provider failed with error:", exc_info=ex)
88189
if self._ota_provider_task:
89190
await self._ota_provider_task
90191

91192
async def add_update(self, update_desc: dict, ota_file: Path) -> None:
92193
"""Add update to the OTA provider."""
93194

94-
update_json_path = DEFAULT_UPDATES_PATH / "updates.json"
95-
96-
def _read_update_json(update_json_path: Path) -> None | UpdateFile:
97-
if not update_json_path.exists():
98-
return None
99-
100-
with open(update_json_path, "r") as json_file:
101-
data = json.load(json_file)
102-
return dataclass_from_dict(UpdateFile, data)
103-
104-
loop = asyncio.get_running_loop()
105-
update_file = await loop.run_in_executor(
106-
None, _read_update_json, update_json_path
107-
)
108-
109-
if not update_file:
110-
update_file = UpdateFile(deviceSoftwareVersionModel=[])
111-
112195
local_ota_url = str(ota_file)
113-
for i, device_software in enumerate(update_file.deviceSoftwareVersionModel):
196+
for i, device_software in enumerate(
197+
self._get_ota_provider_image_list().deviceSoftwareVersionModel
198+
):
114199
if device_software.otaURL == local_ota_url:
115200
LOGGER.debug("Device software entry exists already, replacing!")
116-
del update_file.deviceSoftwareVersionModel[i]
201+
del self._get_ota_provider_image_list().deviceSoftwareVersionModel[i]
117202

118203
# Convert to OTA Requestor descriptor file
119204
new_device_software = DeviceSoftwareVersionModel(
@@ -127,18 +212,8 @@ def _read_update_json(update_json_path: Path) -> None | UpdateFile:
127212
maxApplicableSoftwareVersion=update_desc["maxApplicableSoftwareVersion"],
128213
otaURL=local_ota_url,
129214
)
130-
update_file.deviceSoftwareVersionModel.append(new_device_software)
131-
132-
def _write_update_json(update_json_path: Path, update_file: UpdateFile) -> None:
133-
update_file_dict = asdict(update_file)
134-
with open(update_json_path, "w") as json_file:
135-
json.dump(update_file_dict, json_file, indent=4)
136-
137-
await loop.run_in_executor(
138-
None,
139-
_write_update_json,
140-
update_json_path,
141-
update_file,
215+
self._get_ota_provider_image_list().deviceSoftwareVersionModel.append(
216+
new_device_software
142217
)
143218

144219
async def download_update(self, update_desc: dict) -> None:

0 commit comments

Comments
 (0)