Skip to content

Commit 93f3894

Browse files
committed
Split update WebSocket command into two commands
Make check_node_update a separate WebSocket command which only checks for updates. The update_node command then will download and actually apply the update.
1 parent f698b51 commit 93f3894

File tree

5 files changed

+161
-94
lines changed

5 files changed

+161
-94
lines changed

matter_server/common/models.py

+1
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ class APICommand(str, Enum):
4747
PING_NODE = "ping_node"
4848
GET_NODE_IP_ADDRESSES = "get_node_ip_addresses"
4949
IMPORT_TEST_NODE = "import_test_node"
50+
CHECK_NODE_UPDATE = "check_node_update"
5051
UPDATE_NODE = "update_node"
5152

5253

matter_server/server/device_controller.py

+62-32
Original file line numberDiff line numberDiff line change
@@ -896,8 +896,8 @@ async def import_test_node(self, dump: str) -> None:
896896
self._nodes[node.node_id] = node
897897
self.server.signal_event(EventType.NODE_ADDED, node)
898898

899-
@api_command(APICommand.UPDATE_NODE)
900-
async def update_node(self, node_id: int) -> dict | None:
899+
@api_command(APICommand.CHECK_NODE_UPDATE)
900+
async def check_node_update(self, node_id: int) -> dict | None:
901901
"""
902902
Check if there is an update for a particular node.
903903
@@ -906,8 +906,27 @@ async def update_node(self, node_id: int) -> dict | None:
906906
information of the latest update available.
907907
"""
908908

909-
node_logger = LOGGER.getChild(f"node_{node_id}")
910-
node = self._nodes[node_id]
909+
return await self._check_node_update(node_id)
910+
911+
@api_command(APICommand.UPDATE_NODE)
912+
async def update_node(self, node_id: int, software_version: int) -> dict | None:
913+
"""
914+
Update a node to a new software version.
915+
916+
This command checks if the requested software version is indeed still available
917+
and if so, it will start the update process. The update process will be handled
918+
by the built-in OTA provider. The OTA provider will download the update and
919+
notify the node about the new update.
920+
"""
921+
922+
update = await self._check_node_update(node_id, software_version)
923+
if update is None:
924+
logging.error(
925+
"Software version %d is not available for node %d",
926+
software_version,
927+
node_id,
928+
)
929+
return None
911930

912931
if self.chip_controller is None:
913932
raise RuntimeError("Device Controller not initialized.")
@@ -916,34 +935,7 @@ async def update_node(self, node_id: int) -> dict | None:
916935
LOGGER.warning("No OTA provider found, updates not possible.")
917936
return None
918937

919-
node_logger.debug("Check for updates.")
920-
vid = cast(int, node.attributes.get(BASIC_INFORMATION_VENDOR_ID_ATTRIBUTE_PATH))
921-
pid = cast(
922-
int, node.attributes.get(BASIC_INFORMATION_PRODUCT_ID_ATTRIBUTE_PATH)
923-
)
924-
software_version = cast(
925-
int, node.attributes.get(BASIC_INFORMATION_SOFTWARE_VERSION_ATTRIBUTE_PATH)
926-
)
927-
software_version_string = node.attributes.get(
928-
BASIC_INFORMATION_SOFTWARE_VERSION_STRING_ATTRIBUTE_PATH
929-
)
930-
931-
update = await check_for_update(vid, pid, software_version)
932-
if not update:
933-
node_logger.info("No new update found.")
934-
return None
935-
936-
if "otaUrl" not in update:
937-
node_logger.warning("Update found, but no OTA URL provided.")
938-
return None
939-
940-
node_logger.info(
941-
"New software update found: %s (current %s). Preparing updates...",
942-
update["softwareVersionString"],
943-
software_version_string,
944-
)
945-
946-
# Add to OTA provider
938+
# Add update to the OTA provider
947939
await self._ota_provider.download_update(update)
948940

949941
ota_provider_node_id = self._ota_provider.get_node_id()
@@ -1042,6 +1034,44 @@ async def update_node(self, node_id: int) -> dict | None:
10421034

10431035
return update
10441036

1037+
async def _check_node_update(
1038+
self,
1039+
node_id: int,
1040+
requested_software_version: int | None = None,
1041+
) -> dict | None:
1042+
node_logger = LOGGER.getChild(f"node_{node_id}")
1043+
node = self._nodes[node_id]
1044+
1045+
node_logger.debug("Check for updates.")
1046+
vid = cast(int, node.attributes.get(BASIC_INFORMATION_VENDOR_ID_ATTRIBUTE_PATH))
1047+
pid = cast(
1048+
int, node.attributes.get(BASIC_INFORMATION_PRODUCT_ID_ATTRIBUTE_PATH)
1049+
)
1050+
software_version = cast(
1051+
int, node.attributes.get(BASIC_INFORMATION_SOFTWARE_VERSION_ATTRIBUTE_PATH)
1052+
)
1053+
software_version_string = node.attributes.get(
1054+
BASIC_INFORMATION_SOFTWARE_VERSION_STRING_ATTRIBUTE_PATH
1055+
)
1056+
1057+
update = await check_for_update(
1058+
vid, pid, software_version, requested_software_version
1059+
)
1060+
if not update:
1061+
node_logger.info("No new update found.")
1062+
return None
1063+
1064+
if "otaUrl" not in update:
1065+
node_logger.warning("Update found, but no OTA URL provided.")
1066+
return None
1067+
1068+
node_logger.info(
1069+
"New software update found: %s (current %s).",
1070+
update["softwareVersionString"],
1071+
software_version_string,
1072+
)
1073+
return update
1074+
10451075
async def _subscribe_node(self, node_id: int) -> None:
10461076
"""
10471077
Subscribe to all node state changes/events for an individual node.

matter_server/server/ota/__init__.py

+10-2
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,18 @@
1919

2020

2121
async def check_for_update(
22-
vid: int, pid: int, current_software_version: int
22+
vid: int,
23+
pid: int,
24+
current_software_version: int,
25+
requested_software_version: int | None = None,
2326
) -> None | dict:
2427
"""Check for software updates."""
2528
if (vid, pid) in HARDCODED_UPDATES:
26-
return HARDCODED_UPDATES[(vid, pid)]
29+
update = HARDCODED_UPDATES[(vid, pid)]
30+
if (
31+
requested_software_version is None
32+
or update["softwareVersion"] == requested_software_version
33+
):
34+
return update
2735

2836
return await dcl.check_for_update(vid, pid, current_software_version)

matter_server/server/ota/dcl.py

+42-24
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
LOGGER = logging.getLogger(__name__)
1111

1212

13-
async def get_software_versions(vid: int, pid: int) -> Any:
13+
async def _get_software_versions(vid: int, pid: int) -> Any:
1414
"""Check DCL if there are updates available for a particular node."""
1515
async with ClientSession(raise_for_status=True) as http_session:
1616
# fetch the paa certificates list
@@ -20,7 +20,7 @@ async def get_software_versions(vid: int, pid: int) -> Any:
2020
return await response.json()
2121

2222

23-
async def get_software_version(vid: int, pid: int, software_version: int) -> Any:
23+
async def _get_software_version(vid: int, pid: int, software_version: int) -> Any:
2424
"""Check DCL if there are updates available for a particular node."""
2525
async with ClientSession(raise_for_status=True) as http_session:
2626
# fetch the paa certificates list
@@ -30,12 +30,45 @@ async def get_software_version(vid: int, pid: int, software_version: int) -> Any
3030
return await response.json()
3131

3232

33+
async def _check_update_version(
34+
vid: int, pid: int, version: int, current_software_version: int
35+
) -> None | dict:
36+
version_res: dict = await _get_software_version(vid, pid, version)
37+
if not isinstance(version_res, dict):
38+
raise TypeError("Unexpected DCL response.")
39+
40+
if "modelVersion" not in version_res:
41+
raise ValueError("Unexpected DCL response.")
42+
43+
version_candidate: dict = cast(dict, version_res["modelVersion"])
44+
45+
# Check minApplicableSoftwareVersion/maxApplicableSoftwareVersion
46+
min_sw_version = version_candidate["minApplicableSoftwareVersion"]
47+
max_sw_version = version_candidate["maxApplicableSoftwareVersion"]
48+
if (
49+
current_software_version < min_sw_version
50+
or current_software_version > max_sw_version
51+
):
52+
return None
53+
54+
return version_candidate
55+
56+
3357
async def check_for_update(
34-
vid: int, pid: int, current_software_version: int
58+
vid: int,
59+
pid: int,
60+
current_software_version: int,
61+
requested_software_version: int | None = None,
3562
) -> None | dict:
36-
"""Check if there is a newer software version available on the DCL."""
63+
"""Check if there is a software update available on the DCL."""
3764
try:
38-
versions = await get_software_versions(vid, pid)
65+
if requested_software_version is not None:
66+
return await _check_update_version(
67+
vid, pid, requested_software_version, current_software_version
68+
)
69+
70+
# Get all versions and check each one of them.
71+
versions = await _get_software_versions(vid, pid)
3972

4073
all_software_versions: list[int] = versions["modelVersions"]["softwareVersions"]
4174
newer_software_versions = [
@@ -51,26 +84,11 @@ async def check_for_update(
5184

5285
# Check if latest firmware is applicable, and backtrack from there
5386
for version in sorted(newer_software_versions, reverse=True):
54-
version_res: dict = await get_software_version(vid, pid, version)
55-
if not isinstance(version_res, dict):
56-
raise TypeError("Unexpected DCL response.")
57-
58-
if "modelVersion" not in version_res:
59-
raise ValueError("Unexpected DCL response.")
60-
61-
version_candidate: dict = cast(dict, version_res["modelVersion"])
62-
63-
# Check minApplicableSoftwareVersion/maxApplicableSoftwareVersion
64-
min_sw_version = version_candidate["minApplicableSoftwareVersion"]
65-
max_sw_version = version_candidate["maxApplicableSoftwareVersion"]
66-
if (
67-
current_software_version < min_sw_version
68-
or current_software_version > max_sw_version
87+
if version_candidate := await _check_update_version(
88+
vid, pid, version, current_software_version
6989
):
70-
LOGGER.debug("Software version %d not applicable.", version)
71-
continue
72-
73-
return version_candidate
90+
return version_candidate
91+
LOGGER.debug("Software version %d not applicable.", version)
7492
return None
7593

7694
except (ClientError, TimeoutError) as err:

tests/server/ota/test_dcl.py

+46-36
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
from unittest.mock import AsyncMock, patch
44

5+
import pytest
6+
57
from matter_server.server.ota.dcl import check_for_update
68

79
# Mock the DCL responses (sample from https://on.dcl.csa-iot.org/dcl/model/versions/4447/8194)
@@ -35,41 +37,49 @@
3537
}
3638

3739

38-
async def test_check_updates():
40+
@pytest.fixture(name="get_software_versions")
41+
def mock_get_software_versions():
42+
"""Mock the _get_software_versions function."""
43+
with patch(
44+
"matter_server.server.ota.dcl._get_software_versions",
45+
new_callable=AsyncMock,
46+
return_value=DCL_RESPONSE_SOFTWARE_VERSIONS,
47+
) as mock:
48+
yield mock
49+
50+
51+
@pytest.fixture(name="get_software_version")
52+
def mock_get_software_version():
53+
"""Mock the _get_software_version function."""
54+
with patch(
55+
"matter_server.server.ota.dcl._get_software_version",
56+
new_callable=AsyncMock,
57+
return_value=DCL_RESPONSE_SOFTWARE_VERSION_1011,
58+
) as mock:
59+
yield mock
60+
61+
62+
async def test_check_updates(get_software_versions, get_software_version):
3963
"""Test the case where the latest software version is applicable."""
40-
with (
41-
patch(
42-
"matter_server.server.ota.dcl.get_software_versions",
43-
new_callable=AsyncMock,
44-
return_value=DCL_RESPONSE_SOFTWARE_VERSIONS,
45-
),
46-
patch(
47-
"matter_server.server.ota.dcl.get_software_version",
48-
new_callable=AsyncMock,
49-
return_value=DCL_RESPONSE_SOFTWARE_VERSION_1011,
50-
),
51-
):
52-
# Call the function with a current software version of 1000
53-
result = await check_for_update(4447, 8194, 1000)
54-
55-
assert result == DCL_RESPONSE_SOFTWARE_VERSION_1011["modelVersion"]
56-
57-
58-
async def test_check_updates_not_applicable():
64+
# Call the function with a current software version of 1000
65+
result = await check_for_update(4447, 8194, 1000)
66+
67+
assert result == DCL_RESPONSE_SOFTWARE_VERSION_1011["modelVersion"]
68+
69+
70+
async def test_check_updates_not_applicable(
71+
get_software_versions, get_software_version
72+
):
5973
"""Test the case where the latest software version is not applicable."""
60-
with (
61-
patch(
62-
"matter_server.server.ota.dcl.get_software_versions",
63-
new_callable=AsyncMock,
64-
return_value=DCL_RESPONSE_SOFTWARE_VERSIONS,
65-
),
66-
patch(
67-
"matter_server.server.ota.dcl.get_software_version",
68-
new_callable=AsyncMock,
69-
return_value=DCL_RESPONSE_SOFTWARE_VERSION_1011,
70-
),
71-
):
72-
# Call the function with a current software version of 1
73-
result = await check_for_update(4447, 8194, 1)
74-
75-
assert result is None
74+
# Call the function with a current software version of 1
75+
result = await check_for_update(4447, 8194, 1)
76+
77+
assert result is None
78+
79+
80+
async def test_check_updates_specific_version(get_software_version):
81+
"""Test the case to get a specific version."""
82+
# Call the function with a current software version of 1000 and request 1011 as update
83+
result = await check_for_update(4447, 8194, 1000, 1011)
84+
85+
assert result == DCL_RESPONSE_SOFTWARE_VERSION_1011["modelVersion"]

0 commit comments

Comments
 (0)