Skip to content

Commit 9d7717f

Browse files
committed
Setup OTA Provider App automatically when necessary
Start and commission OTA Provider App when necessary. Use random discriminator and passcode. Store the Node ID of the OTA Provider App once setup for fast re-use.
1 parent 2bab7e3 commit 9d7717f

File tree

5 files changed

+221
-46
lines changed

5 files changed

+221
-46
lines changed

matter_server/server/__main__.py

+7
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,12 @@
116116
nargs="+",
117117
help="List of node IDs to show logs from (applies only to server logs).",
118118
)
119+
parser.add_argument(
120+
"--ota-provider-dir",
121+
type=str,
122+
default=None,
123+
help="Directory where OTA Provider stores software updates and configuration.",
124+
)
119125

120126
args = parser.parse_args()
121127

@@ -216,6 +222,7 @@ def main() -> None:
216222
args.paa_root_cert_dir,
217223
args.enable_test_net_dcl,
218224
args.bluetooth_adapter,
225+
args.ota_provider_dir,
219226
)
220227

221228
async def handle_stop(loop: asyncio.AbstractEventLoop) -> None:

matter_server/server/const.py

+2
Original file line numberDiff line numberDiff line change
@@ -20,3 +20,5 @@
2020
.parent.resolve()
2121
.joinpath("credentials/development/paa-root-certs")
2222
)
23+
24+
DEFAULT_OTA_PROVIDER_DIR: Final[pathlib.Path] = pathlib.Path().cwd().joinpath("updates")

matter_server/server/device_controller.py

+83-3
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,12 @@
1818
from typing import TYPE_CHECKING, Any, cast
1919

2020
from chip.ChipDeviceCtrl import ChipDeviceController
21-
from chip.clusters import Attribute, Objects as Clusters
21+
from chip.clusters import Attribute, Objects as Clusters, Types
2222
from chip.clusters.Attribute import ValueDecodeFailure
2323
from chip.clusters.ClusterObjects import ALL_ATTRIBUTES, ALL_CLUSTERS, Cluster
2424
from chip.discovery import DiscoveryType
2525
from chip.exceptions import ChipStackError
26+
from chip.interaction_model import Status
2627
from zeroconf import BadTypeInNameException, IPVersion, ServiceStateChange, Zeroconf
2728
from zeroconf.asyncio import AsyncServiceBrowser, AsyncServiceInfo, AsyncZeroconf
2829

@@ -120,6 +121,7 @@ def __init__(
120121
self,
121122
server: MatterServer,
122123
paa_root_cert_dir: Path,
124+
ota_provider_dir: Path,
123125
):
124126
"""Initialize the device controller."""
125127
self.server = server
@@ -150,14 +152,15 @@ def __init__(
150152
self._polled_attributes: dict[int, set[str]] = {}
151153
self._custom_attribute_poller_timer: asyncio.TimerHandle | None = None
152154
self._custom_attribute_poller_task: asyncio.Task | None = None
153-
self._ota_provider = ExternalOtaProvider()
155+
self._ota_provider = ExternalOtaProvider(ota_provider_dir)
154156

155157
async def initialize(self) -> None:
156158
"""Initialize the device controller."""
157159
self._compressed_fabric_id = (
158160
await self._chip_device_controller.get_compressed_fabric_id()
159161
)
160162
self._fabric_id_hex = hex(self._compressed_fabric_id)[2:]
163+
await self._ota_provider.initialize()
161164

162165
async def start(self) -> None:
163166
"""Handle logic on controller start."""
@@ -943,17 +946,94 @@ async def update_node(self, node_id: int) -> dict | None:
943946
# Add to OTA provider
944947
await self._ota_provider.download_update(update)
945948

949+
ota_provider_node_id = self._ota_provider.get_node_id()
950+
if ota_provider_node_id not in self._nodes:
951+
LOGGER.warning(
952+
"OTA Provider node id %d no longer exists! Resetting...",
953+
ota_provider_node_id,
954+
)
955+
await self._ota_provider.reset()
956+
ota_provider_node_id = None
957+
958+
# Make sure any previous instances get stopped
959+
await self._ota_provider.stop()
946960
self._ota_provider.start()
947961

948962
# Wait for OTA provider to be ready
949963
# TODO: Detect when OTA provider is ready
950964
await asyncio.sleep(2)
951965

966+
if not ota_provider_node_id:
967+
# The OTA Provider has not been commissioned yet, let's do it now.
968+
LOGGER.info("Commissioning the built-in OTA Provider App.")
969+
try:
970+
ota_provider_node = await self.commission_on_network(
971+
self._ota_provider.get_passcode(),
972+
# TODO: Filtering by long discriminator seems broken
973+
# filter_type=FilterType.LONG_DISCRIMINATOR,
974+
# filter=self._ota_provider.get_descriminator(),
975+
)
976+
ota_provider_node_id = ota_provider_node.node_id
977+
except NodeCommissionFailed:
978+
LOGGER.error("Failed to commission OTA Provider App!")
979+
return None
980+
LOGGER.info(
981+
"OTA Provider App commissioned with node id %d.",
982+
ota_provider_node_id,
983+
)
984+
985+
# Adjust ACL of OTA Requestor such that Node peer-to-peer communication
986+
# is allowed.
987+
try:
988+
read_result = await self.chip_controller.ReadAttribute(
989+
ota_provider_node_id, [(0, Clusters.AccessControl.Attributes.Acl)]
990+
)
991+
acl_list = cast(
992+
list,
993+
read_result[0][Clusters.AccessControl][
994+
Clusters.AccessControl.Attributes.Acl
995+
],
996+
)
997+
998+
# Add new ACL entry...
999+
acl_list.append(
1000+
Clusters.AccessControl.Structs.AccessControlEntryStruct(
1001+
fabricIndex=1,
1002+
privilege=3,
1003+
authMode=2,
1004+
subjects=Types.NullValue,
1005+
targets=[
1006+
Clusters.AccessControl.Structs.AccessControlTargetStruct(
1007+
cluster=41, endpoint=0, deviceType=Types.NullValue
1008+
)
1009+
],
1010+
)
1011+
)
1012+
1013+
# And write. This is persistent, so only need to be done after we commissioned
1014+
# the OTA Provider App.
1015+
write_result: Attribute.AttributeWriteResult = (
1016+
await self.chip_controller.WriteAttribute(
1017+
ota_provider_node_id,
1018+
[(0, Clusters.AccessControl.Attributes.Acl(acl_list))],
1019+
)
1020+
)
1021+
if write_result[0].Status != Status.Success:
1022+
logging.error("Failed writing adjusted OTA Provider App ACL.")
1023+
await self.remove_node(ota_provider_node_id)
1024+
return None
1025+
except ChipStackError as ex:
1026+
logging.exception("Failed adjusting OTA Provider App ACL.", exc_info=ex)
1027+
await self.remove_node(ota_provider_node_id)
1028+
else:
1029+
self._ota_provider.set_node_id(ota_provider_node_id)
1030+
1031+
# Notify node about the new update!
9521032
await self.chip_controller.SendCommand(
9531033
nodeid=node_id,
9541034
endpoint=0,
9551035
payload=Clusters.OtaSoftwareUpdateRequestor.Commands.AnnounceOTAProvider(
956-
providerNodeID=32,
1036+
providerNodeID=ota_provider_node_id,
9571037
vendorID=0, # TODO: Use Server Vendor ID
9581038
announcementReason=Clusters.OtaSoftwareUpdateRequestor.Enums.AnnouncementReasonEnum.kUpdateAvailable,
9591039
endpoint=0,

matter_server/server/ota/provider.py

+114-39
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import json
77
import logging
88
from pathlib import Path
9+
import secrets
910
from typing import TYPE_CHECKING, Final
1011
from urllib.parse import unquote, urlparse
1112

@@ -37,9 +38,12 @@ class DeviceSoftwareVersionModel: # pylint: disable=C0103
3738

3839

3940
@dataclass
40-
class UpdateFile: # pylint: disable=C0103
41+
class OtaProviderImageList: # pylint: disable=C0103
4142
"""Update File for OTA Provider JSON descriptor file."""
4243

44+
otaProviderDiscriminator: int
45+
otaProviderPasscode: int
46+
otaProviderNodeId: int | None
4347
deviceSoftwareVersionModel: list[DeviceSoftwareVersionModel]
4448

4549

@@ -50,23 +54,103 @@ class ExternalOtaProvider:
5054
for devices.
5155
"""
5256

53-
def __init__(self) -> None:
57+
def __init__(self, ota_provider_dir: Path) -> None:
5458
"""Initialize the OTA provider."""
59+
self._ota_provider_dir: Path = ota_provider_dir
60+
self._ota_provider_image_list_file: Path = ota_provider_dir / "updates.json"
61+
self._ota_provider_image_list: OtaProviderImageList | None = None
5562
self._ota_provider_proc: Process | None = None
5663
self._ota_provider_task: asyncio.Task | None = None
5764

65+
async def initialize(self) -> None:
66+
"""Initialize OTA Provider."""
67+
68+
loop = asyncio.get_event_loop()
69+
70+
# Take existence of image list file as indicator if we need to initialize the
71+
# OTA Provider.
72+
if not await loop.run_in_executor(
73+
None, self._ota_provider_image_list_file.exists
74+
):
75+
await loop.run_in_executor(
76+
None, functools.partial(DEFAULT_UPDATES_PATH.mkdir, exist_ok=True)
77+
)
78+
79+
# Initialize with random data. Node ID will get written once paired by
80+
# device controller.
81+
self._ota_provider_image_list = OtaProviderImageList(
82+
otaProviderDiscriminator=secrets.randbelow(2**12),
83+
otaProviderPasscode=secrets.randbelow(2**21),
84+
otaProviderNodeId=None,
85+
deviceSoftwareVersionModel=[],
86+
)
87+
else:
88+
89+
def _read_update_json(
90+
update_json_path: Path,
91+
) -> None | OtaProviderImageList:
92+
with open(update_json_path, "r") as json_file:
93+
data = json.load(json_file)
94+
return dataclass_from_dict(OtaProviderImageList, data)
95+
96+
self._ota_provider_image_list = await loop.run_in_executor(
97+
None, _read_update_json, self._ota_provider_image_list_file
98+
)
99+
100+
def _get_ota_provider_image_list(self) -> OtaProviderImageList:
101+
if self._ota_provider_image_list is None:
102+
raise RuntimeError("OTA provider image list not initialized.")
103+
return self._ota_provider_image_list
104+
105+
def get_node_id(self) -> int | None:
106+
"""Get Node ID of the OTA Provider App."""
107+
108+
return self._get_ota_provider_image_list().otaProviderNodeId
109+
110+
def get_descriminator(self) -> int:
111+
"""Return OTA Provider App discriminator."""
112+
113+
return self._get_ota_provider_image_list().otaProviderDiscriminator
114+
115+
def get_passcode(self) -> int:
116+
"""Return OTA Provider App passcode."""
117+
118+
return self._get_ota_provider_image_list().otaProviderPasscode
119+
120+
def set_node_id(self, node_id: int) -> None:
121+
"""Set Node ID of the OTA Provider App."""
122+
123+
self._get_ota_provider_image_list().otaProviderNodeId = node_id
124+
58125
async def _start_ota_provider(self) -> None:
59-
# TODO: Randomize discriminator
126+
def _write_ota_provider_image_list_json(
127+
ota_provider_image_list_file: Path,
128+
ota_provider_image_list: OtaProviderImageList,
129+
) -> None:
130+
update_file_dict = asdict(ota_provider_image_list)
131+
with open(ota_provider_image_list_file, "w") as json_file:
132+
json.dump(update_file_dict, json_file, indent=4)
133+
134+
loop = asyncio.get_running_loop()
135+
await loop.run_in_executor(
136+
None,
137+
_write_ota_provider_image_list_json,
138+
self._ota_provider_image_list_file,
139+
self._get_ota_provider_image_list(),
140+
)
141+
60142
ota_provider_cmd = [
61143
"chip-ota-provider-app",
62144
"--discriminator",
63-
"22",
145+
str(self._get_ota_provider_image_list().otaProviderDiscriminator),
146+
"--passcode",
147+
str(self._get_ota_provider_image_list().otaProviderPasscode),
64148
"--secured-device-port",
65149
"5565",
66150
"--KVS",
67-
"/data/chip_kvs_provider",
151+
str(self._ota_provider_dir / "chip_kvs_ota_provider"),
68152
"--otaImageList",
69-
str(DEFAULT_UPDATES_PATH / "updates.json"),
153+
str(self._ota_provider_image_list_file),
70154
]
71155

72156
LOGGER.info("Starting OTA Provider")
@@ -80,40 +164,41 @@ def start(self) -> None:
80164
loop = asyncio.get_event_loop()
81165
self._ota_provider_task = loop.create_task(self._start_ota_provider())
82166

167+
async def reset(self) -> None:
168+
"""Reset the OTA Provider App state."""
169+
170+
def _remove_update_data(ota_provider_dir: Path) -> None:
171+
for path in ota_provider_dir.iterdir():
172+
if not path.is_dir():
173+
path.unlink()
174+
175+
loop = asyncio.get_event_loop()
176+
await loop.run_in_executor(None, _remove_update_data, self._ota_provider_dir)
177+
178+
await self.initialize()
179+
83180
async def stop(self) -> None:
84181
"""Stop the OTA Provider."""
85182
if self._ota_provider_proc:
86183
LOGGER.info("Terminating OTA Provider")
87-
self._ota_provider_proc.terminate()
184+
loop = asyncio.get_event_loop()
185+
try:
186+
await loop.run_in_executor(None, self._ota_provider_proc.terminate)
187+
except ProcessLookupError as ex:
188+
LOGGER.warning("Stopping OTA Provider failed with error:", exc_info=ex)
88189
if self._ota_provider_task:
89190
await self._ota_provider_task
90191

91192
async def add_update(self, update_desc: dict, ota_file: Path) -> None:
92193
"""Add update to the OTA provider."""
93194

94-
update_json_path = DEFAULT_UPDATES_PATH / "updates.json"
95-
96-
def _read_update_json(update_json_path: Path) -> None | UpdateFile:
97-
if not update_json_path.exists():
98-
return None
99-
100-
with open(update_json_path, "r") as json_file:
101-
data = json.load(json_file)
102-
return dataclass_from_dict(UpdateFile, data)
103-
104-
loop = asyncio.get_running_loop()
105-
update_file = await loop.run_in_executor(
106-
None, _read_update_json, update_json_path
107-
)
108-
109-
if not update_file:
110-
update_file = UpdateFile(deviceSoftwareVersionModel=[])
111-
112195
local_ota_url = str(ota_file)
113-
for i, device_software in enumerate(update_file.deviceSoftwareVersionModel):
196+
for i, device_software in enumerate(
197+
self._get_ota_provider_image_list().deviceSoftwareVersionModel
198+
):
114199
if device_software.otaURL == local_ota_url:
115200
LOGGER.debug("Device software entry exists already, replacing!")
116-
del update_file.deviceSoftwareVersionModel[i]
201+
del self._get_ota_provider_image_list().deviceSoftwareVersionModel[i]
117202

118203
# Convert to OTA Requestor descriptor file
119204
new_device_software = DeviceSoftwareVersionModel(
@@ -127,18 +212,8 @@ def _read_update_json(update_json_path: Path) -> None | UpdateFile:
127212
maxApplicableSoftwareVersion=update_desc["maxApplicableSoftwareVersion"],
128213
otaURL=local_ota_url,
129214
)
130-
update_file.deviceSoftwareVersionModel.append(new_device_software)
131-
132-
def _write_update_json(update_json_path: Path, update_file: UpdateFile) -> None:
133-
update_file_dict = asdict(update_file)
134-
with open(update_json_path, "w") as json_file:
135-
json.dump(update_file_dict, json_file, indent=4)
136-
137-
await loop.run_in_executor(
138-
None,
139-
_write_update_json,
140-
update_json_path,
141-
update_file,
215+
self._get_ota_provider_image_list().deviceSoftwareVersionModel.append(
216+
new_device_software
142217
)
143218

144219
async def download_update(self, update_desc: dict) -> None:

0 commit comments

Comments
 (0)