Skip to content

Commit 6f4fb5c

Browse files
committed
[WIP] python test framework PICS 2.0
1 parent 024b09b commit 6f4fb5c

File tree

5 files changed

+409
-43
lines changed

5 files changed

+409
-43
lines changed

scripts/py_matter_yamltests/matter_yamltests/hooks.py

+5
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,11 @@ def show_prompt(self,
227227
"""
228228
pass
229229

230+
def test_skipped(self, filename: str, name: str):
231+
"""
232+
This method is called when the test script determines that the test is not applicable for the DUT.
233+
"""
234+
230235

231236
class WebSocketRunnerHooks():
232237
def connecting(self, url: str):

src/python_testing/TC_TIMESYNC_2_1.py

+23-31
Original file line numberDiff line numberDiff line change
@@ -27,53 +27,45 @@
2727

2828
import chip.clusters as Clusters
2929
from chip.clusters.Types import NullValue
30-
from matter_testing_support import MatterBaseTest, async_test_body, default_matter_test_main, utc_time_in_matter_epoch
30+
from matter_testing_support import MatterBaseTest, default_matter_test_main, utc_time_in_matter_epoch, per_endpoint_test, has_cluster, has_attribute
3131
from mobly import asserts
3232

3333

3434
class TC_TIMESYNC_2_1(MatterBaseTest):
35-
async def read_ts_attribute_expect_success(self, endpoint, attribute):
35+
async def read_ts_attribute_expect_success(self, attribute):
3636
cluster = Clusters.Objects.TimeSynchronization
37-
return await self.read_single_attribute_check_success(endpoint=endpoint, cluster=cluster, attribute=attribute)
37+
return await self.read_single_attribute_check_success(endpoint=None, cluster=cluster, attribute=attribute)
3838

39-
def pics_TC_TIMESYNC_2_1(self) -> list[str]:
40-
return ["TIMESYNC.S"]
41-
42-
@async_test_body
39+
@per_endpoint_test(has_cluster(Clusters.TimeSynchronization) and has_attribute(Clusters.TimeSynchronization.Attributes.TimeSource))
4340
async def test_TC_TIMESYNC_2_1(self):
44-
endpoint = 0
45-
46-
features = await self.read_single_attribute(dev_ctrl=self.default_controller, node_id=self.dut_node_id,
47-
endpoint=endpoint, attribute=Clusters.TimeSynchronization.Attributes.FeatureMap)
41+
attributes = Clusters.TimeSynchronization.Attributes
42+
features = await self.read_ts_attribute_expect_success(attribute=attributes.FeatureMap)
4843

4944
self.supports_time_zone = bool(features & Clusters.TimeSynchronization.Bitmaps.Feature.kTimeZone)
5045
self.supports_ntpc = bool(features & Clusters.TimeSynchronization.Bitmaps.Feature.kNTPClient)
5146
self.supports_ntps = bool(features & Clusters.TimeSynchronization.Bitmaps.Feature.kNTPServer)
5247
self.supports_trusted_time_source = bool(features & Clusters.TimeSynchronization.Bitmaps.Feature.kTimeSyncClient)
5348

54-
time_cluster = Clusters.TimeSynchronization
55-
timesync_attr_list = time_cluster.Attributes.AttributeList
56-
attribute_list = await self.read_single_attribute_check_success(endpoint=endpoint, cluster=time_cluster, attribute=timesync_attr_list)
57-
timesource_attr_id = time_cluster.Attributes.TimeSource.attribute_id
49+
timesync_attr_list = attributes.AttributeList
50+
attribute_list = await self.read_ts_attribute_expect_success(attribute=timesync_attr_list)
51+
timesource_attr_id = attributes.TimeSource.attribute_id
5852

5953
self.print_step(1, "Commissioning, already done")
60-
attributes = Clusters.TimeSynchronization.Attributes
6154

6255
self.print_step(2, "Read Granularity attribute")
63-
granularity_dut = await self.read_ts_attribute_expect_success(endpoint=endpoint, attribute=attributes.Granularity)
56+
granularity_dut = await self.read_ts_attribute_expect_success(attribute=attributes.Granularity)
6457
asserts.assert_less(granularity_dut, Clusters.TimeSynchronization.Enums.GranularityEnum.kUnknownEnumValue,
6558
"Granularity is not in valid range")
6659

6760
self.print_step(3, "Read TimeSource")
6861
if timesource_attr_id in attribute_list:
69-
time_source = await self.read_ts_attribute_expect_success(endpoint=endpoint, attribute=attributes.TimeSource)
62+
time_source = await self.read_ts_attribute_expect_success(attribute=attributes.TimeSource)
7063
asserts.assert_less(time_source, Clusters.TimeSynchronization.Enums.TimeSourceEnum.kUnknownEnumValue,
7164
"TimeSource is not in valid range")
7265

7366
self.print_step(4, "Read TrustedTimeSource")
7467
if self.supports_trusted_time_source:
75-
trusted_time_source = await self.read_ts_attribute_expect_success(endpoint=endpoint,
76-
attribute=attributes.TrustedTimeSource)
68+
trusted_time_source = await self.read_ts_attribute_expect_success(attribute=attributes.TrustedTimeSource)
7769
if trusted_time_source is not NullValue:
7870
asserts.assert_less_equal(trusted_time_source.fabricIndex, 0xFE,
7971
"FabricIndex for the TrustedTimeSource is out of range")
@@ -82,7 +74,7 @@ async def test_TC_TIMESYNC_2_1(self):
8274

8375
self.print_step(5, "Read DefaultNTP")
8476
if self.supports_ntpc:
85-
default_ntp = await self.read_ts_attribute_expect_success(endpoint=endpoint, attribute=attributes.DefaultNTP)
77+
default_ntp = await self.read_ts_attribute_expect_success(attribute=attributes.DefaultNTP)
8678
if default_ntp is not NullValue:
8779
asserts.assert_less_equal(len(default_ntp), 128, "DefaultNTP length must be less than 128")
8880
# Assume this is a valid web address if it has at least one . in the name
@@ -97,7 +89,7 @@ async def test_TC_TIMESYNC_2_1(self):
9789

9890
self.print_step(6, "Read TimeZone")
9991
if self.supports_time_zone:
100-
tz_dut = await self.read_ts_attribute_expect_success(endpoint=endpoint, attribute=attributes.TimeZone)
92+
tz_dut = await self.read_ts_attribute_expect_success(attribute=attributes.TimeZone)
10193
asserts.assert_greater_equal(len(tz_dut), 1, "TimeZone must have at least one entry in the list")
10294
asserts.assert_less_equal(len(tz_dut), 2, "TimeZone may have a maximum of two entries in the list")
10395
for entry in tz_dut:
@@ -112,7 +104,7 @@ async def test_TC_TIMESYNC_2_1(self):
112104

113105
self.print_step(7, "Read DSTOffset")
114106
if self.supports_time_zone:
115-
dst_dut = await self.read_ts_attribute_expect_success(endpoint=endpoint, attribute=attributes.DSTOffset)
107+
dst_dut = await self.read_ts_attribute_expect_success(attribute=attributes.DSTOffset)
116108
last_valid_until = -1
117109
last_valid_starting = -1
118110
for dst in dst_dut:
@@ -126,7 +118,7 @@ async def test_TC_TIMESYNC_2_1(self):
126118
asserts.assert_equal(dst, dst_dut[-1], "DSTOffset list must have Null ValidUntil at the end")
127119

128120
self.print_step(8, "Read UTCTime")
129-
utc_dut = await self.read_ts_attribute_expect_success(endpoint=endpoint, attribute=attributes.UTCTime)
121+
utc_dut = await self.read_ts_attribute_expect_success(attribute=attributes.UTCTime)
130122
if utc_dut is NullValue:
131123
asserts.assert_equal(granularity_dut, Clusters.TimeSynchronization.Enums.GranularityEnum.kNoTimeGranularity)
132124
else:
@@ -141,8 +133,8 @@ async def test_TC_TIMESYNC_2_1(self):
141133

142134
self.print_step(9, "Read LocalTime")
143135
if self.supports_time_zone:
144-
utc_dut = await self.read_ts_attribute_expect_success(endpoint=endpoint, attribute=attributes.UTCTime)
145-
local_dut = await self.read_ts_attribute_expect_success(endpoint=endpoint, attribute=attributes.LocalTime)
136+
utc_dut = await self.read_ts_attribute_expect_success(attribute=attributes.UTCTime)
137+
local_dut = await self.read_ts_attribute_expect_success(attribute=attributes.LocalTime)
146138
if utc_dut is NullValue:
147139
asserts.assert_true(local_dut is NullValue, "LocalTime must be Null if UTC time is Null")
148140
elif len(dst_dut) == 0:
@@ -156,30 +148,30 @@ async def test_TC_TIMESYNC_2_1(self):
156148

157149
self.print_step(10, "Read TimeZoneDatabase")
158150
if self.supports_time_zone:
159-
tz_db_dut = await self.read_ts_attribute_expect_success(endpoint=endpoint, attribute=attributes.TimeZoneDatabase)
151+
tz_db_dut = await self.read_ts_attribute_expect_success(attribute=attributes.TimeZoneDatabase)
160152
asserts.assert_less(tz_db_dut, Clusters.TimeSynchronization.Enums.TimeZoneDatabaseEnum.kUnknownEnumValue,
161153
"TimeZoneDatabase is not in valid range")
162154

163155
self.print_step(11, "Read NTPServerAvailable")
164156
if self.supports_ntps:
165157
# bool typechecking happens in the test read functions, so all we need to do here is do the read
166-
await self.read_ts_attribute_expect_success(endpoint=endpoint, attribute=attributes.NTPServerAvailable)
158+
await self.read_ts_attribute_expect_success(attribute=attributes.NTPServerAvailable)
167159

168160
self.print_step(12, "Read TimeZoneListMaxSize")
169161
if self.supports_time_zone:
170-
size = await self.read_ts_attribute_expect_success(endpoint=endpoint, attribute=attributes.TimeZoneListMaxSize)
162+
size = await self.read_ts_attribute_expect_success(attribute=attributes.TimeZoneListMaxSize)
171163
asserts.assert_greater_equal(size, 1, "TimeZoneListMaxSize must be at least 1")
172164
asserts.assert_less_equal(size, 2, "TimeZoneListMaxSize must be max 2")
173165

174166
self.print_step(13, "Read DSTOffsetListMaxSize")
175167
if self.supports_time_zone:
176-
size = await self.read_ts_attribute_expect_success(endpoint=endpoint, attribute=attributes.DSTOffsetListMaxSize)
168+
size = await self.read_ts_attribute_expect_success(attribute=attributes.DSTOffsetListMaxSize)
177169
asserts.assert_greater_equal(size, 1, "DSTOffsetListMaxSize must be at least 1")
178170

179171
self.print_step(14, "Read SupportsDNSResolve")
180172
# bool typechecking happens in the test read functions, so all we need to do here is do the read
181173
if self.supports_ntpc:
182-
await self.read_ts_attribute_expect_success(endpoint=endpoint, attribute=attributes.SupportsDNSResolve)
174+
await self.read_ts_attribute_expect_success(attribute=attributes.SupportsDNSResolve)
183175

184176

185177
if __name__ == "__main__":

src/python_testing/matter_testing_support.py

+93-7
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
from dataclasses import dataclass, field
3535
from datetime import datetime, timedelta, timezone
3636
from enum import Enum
37+
from functools import partial
3738
from typing import List, Optional, Tuple
3839

3940
from chip.tlv import float32, uint
@@ -339,6 +340,8 @@ def show_prompt(self,
339340
placeholder: Optional[str] = None,
340341
default_value: Optional[str] = None) -> None:
341342
pass
343+
def test_skipped(self, filename: str, name: str):
344+
logging.info(f"Skipping test from {filename}: {name}")
342345

343346

344347
@dataclass
@@ -771,8 +774,10 @@ def setup_class(self):
771774

772775
def setup_test(self):
773776
self.current_step_index = 0
777+
self.test_start_time = datetime.now(timezone.utc)
774778
self.step_start_time = datetime.now(timezone.utc)
775779
self.step_skipped = False
780+
self.failed = False
776781
if self.runner_hook and not self.is_commissioning:
777782
test_name = self.current_test_info.name
778783
steps = self.get_defined_test_steps(test_name)
@@ -949,12 +954,11 @@ def on_fail(self, record):
949954
950955
record is of type TestResultRecord
951956
'''
957+
self.failed = True
952958
if self.runner_hook and not self.is_commissioning:
953959
exception = record.termination_signal.exception
954960
step_duration = (datetime.now(timezone.utc) - self.step_start_time) / timedelta(microseconds=1)
955-
# This isn't QUITE the test duration because the commissioning is handled separately, but it's clsoe enough for now
956-
# This is already given in milliseconds
957-
test_duration = record.end_time - record.begin_time
961+
test_duration = datetime.now(timezone.utc) - self.test_start_time
958962
# TODO: I have no idea what logger, logs, request or received are. Hope None works because I have nothing to give
959963
self.runner_hook.step_failure(logger=None, logs=None, duration=step_duration, request=None, received=None)
960964
self.runner_hook.test_stop(exception=exception, duration=test_duration)
@@ -968,7 +972,7 @@ def on_pass(self, record):
968972
# What is request? This seems like an implementation detail for the runner
969973
# TODO: As with failure, I have no idea what logger, logs or request are meant to be
970974
step_duration = (datetime.now(timezone.utc) - self.step_start_time) / timedelta(microseconds=1)
971-
test_duration = record.end_time - record.begin_time
975+
test_duration = datetime.now(timezone.utc) - self.test_start_time
972976
self.runner_hook.step_success(logger=None, logs=None, duration=step_duration, request=None)
973977

974978
# TODO: this check could easily be annoying when doing dev. flag it somehow? Ditto with the in-order check
@@ -986,6 +990,18 @@ def on_pass(self, record):
986990
if self.runner_hook and not self.is_commissioning:
987991
self.runner_hook.test_stop(exception=None, duration=test_duration)
988992

993+
def on_skip(self, record):
994+
''' Called by Mobly on test skip
995+
996+
record is of type TestResultRecord
997+
'''
998+
if self.runner_hook and not self.is_commissioning:
999+
test_duration = record.end_time - record.begin_time
1000+
test_name = self.current_test_info.name
1001+
filename = inspect.getfile(self.__class__)
1002+
self.runner_hook.test_skipped(filename, test_name)
1003+
self.runner_hook.test_stop(exception=None, duration=test_duration)
1004+
9891005
def pics_guard(self, pics_condition: bool):
9901006
"""Checks a condition and if False marks the test step as skipped and
9911007
returns False, otherwise returns True.
@@ -1531,6 +1547,10 @@ def parse_matter_test_args(argv: Optional[List[str]] = None) -> MatterTestConfig
15311547

15321548
return convert_args_to_matter_config(parser.parse_known_args(argv)[0])
15331549

1550+
def _async_runner(body, self: MatterBaseTest, *args, **kwargs):
1551+
timeout = self.matter_test_config.timeout if self.matter_test_config.timeout is not None else self.default_timeout
1552+
runner_with_timeout = asyncio.wait_for(body(self, *args, **kwargs), timeout=timeout)
1553+
return asyncio.run(runner_with_timeout)
15341554

15351555
def async_test_body(body):
15361556
"""Decorator required to be applied whenever a `test_*` method is `async def`.
@@ -1541,12 +1561,78 @@ def async_test_body(body):
15411561
"""
15421562

15431563
def async_runner(self: MatterBaseTest, *args, **kwargs):
1544-
timeout = self.matter_test_config.timeout if self.matter_test_config.timeout is not None else self.default_timeout
1545-
runner_with_timeout = asyncio.wait_for(body(self, *args, **kwargs), timeout=timeout)
1546-
return asyncio.run(runner_with_timeout)
1564+
return _async_runner(body, self, *args, **kwargs)
15471565

15481566
return async_runner
15491567

1568+
def per_node_test(body):
1569+
1570+
""" Decorator to be used for PICS-free tests that apply to the entire node.
1571+
1572+
Use this decorator when your script needs to be run once to validate the whole node.
1573+
To use this decorator, the test must NOT have an associated pics_ method.
1574+
"""
1575+
def whole_node_runner(self: MatterBaseTest, *args, **kwargs):
1576+
asserts.assert_false(self.get_test_pics(self.current_test_info.name), "pics_ method supplied for per_node_test.")
1577+
return _async_runner(body, self, *args, **kwargs)
1578+
1579+
return whole_node_runner
1580+
1581+
EndpointCheckFunction = typing.Callable[[Clusters.Attribute.AsyncReadTransaction.ReadResponse, int], bool]
1582+
1583+
def _has_cluster(wildcard, endpoint, cluster: ClusterObjects.Cluster) -> bool:
1584+
try:
1585+
return cluster in wildcard.attributes[endpoint]
1586+
except KeyError:
1587+
return False
1588+
1589+
def has_cluster(cluster: ClusterObjects.ClusterObjectDescriptor) -> EndpointCheckFunction:
1590+
return partial(_has_cluster, cluster=cluster)
1591+
1592+
def _has_attribute(wildcard, endpoint, attribute: ClusterObjects.ClusterAttributeDescriptor) -> bool:
1593+
cluster = getattr(Clusters, attribute.__qualname__.split('.')[-3])
1594+
try:
1595+
attr_list = wildcard.attributes[endpoint][cluster][cluster.Attributes.AttributeList]
1596+
return attribute.attribute_id in attr_list
1597+
except KeyError:
1598+
return False
1599+
1600+
def has_attribute(attribute: ClusterObjects.ClusterAttributeDescriptor) -> EndpointCheckFunction:
1601+
return partial(_has_attribute, attribute=attribute)
1602+
1603+
async def get_accepted_endpoints_for_test(self:MatterBaseTest, accept_function: EndpointCheckFunction):
1604+
wildcard = await self.default_controller.Read(self.dut_node_id, [()])
1605+
return [e for e in wildcard.attributes.keys() if accept_function(wildcard, e)]
1606+
1607+
def per_endpoint_test(accept_function):
1608+
def per_endpoint_test_internal(body):
1609+
def per_endpoint_runner(self: MatterBaseTest, *args, **kwargs):
1610+
asserts.assert_false(self.get_test_pics(self.current_test_info.name), "pics_ method supplied for per_endpoint_test.")
1611+
runner_with_timeout = asyncio.wait_for(get_accepted_endpoints_for_test(self, accept_function), timeout=5)
1612+
endpoints = asyncio.run(runner_with_timeout)
1613+
if not endpoints:
1614+
logging.info("No matching endpoints found - skipping test")
1615+
asserts.skip('No endpoints match requirements')
1616+
return
1617+
logging.info(f"Running test on the following endpoints: {endpoints}")
1618+
# setup_class is meant to be called once, but setup_test is expected to be run before
1619+
# each iteration. Mobly will run it for us the first time, but since we're running this
1620+
# more than one time, we want to make sure we reset everything as expected.
1621+
# Ditto for teardown - we want to tear down after each iteration, and we want to notify the hool that
1622+
# the test iteration is stopped. test_stop is called by on_pass or on_fail during the last iteration or
1623+
# on failure.
1624+
for e in endpoints:
1625+
logging.info(f'Running test on endpoint {e}')
1626+
if e != endpoints[0]:
1627+
self.setup_test()
1628+
self.matter_test_config.endpoint = e
1629+
_async_runner(body, self, *args, **kwargs)
1630+
if e != endpoints[-1] and not self.failed:
1631+
self.teardown_test()
1632+
self.runner_hook.test_stop(exception=None, duration=datetime.now(timezone.utc) - self.test_start_time)
1633+
1634+
return per_endpoint_runner
1635+
return per_endpoint_test_internal
15501636

15511637
class CommissionDeviceTest(MatterBaseTest):
15521638
"""Test class auto-injected at the start of test list to commission a device when requested"""

src/python_testing/test_testing/MockTestRunner.py

+8-5
Original file line numberDiff line numberDiff line change
@@ -37,21 +37,24 @@ async def __call__(self, *args, **kwargs):
3737

3838

3939
class MockTestRunner():
40-
def __init__(self, filename: str, classname: str, test: str, endpoint: int, pics: dict[str, bool] = {}):
41-
self.config = MatterTestConfig(
42-
tests=[test], endpoint=endpoint, dut_node_ids=[1], pics=pics)
40+
def __init__(self, filename: str, classname: str, test: str, endpoint: int = 0, pics: dict[str, bool] = {}):
41+
self.set_test(filename, classname, test, endpoint, pics)
4342
self.stack = MatterStackState(self.config)
4443
self.default_controller = self.stack.certificate_authorities[0].adminList[0].NewController(
4544
nodeId=self.config.controller_node_id,
4645
paaTrustStorePath=str(self.config.paa_trust_store_path),
4746
catTags=self.config.controller_cat_tags
4847
)
48+
49+
def set_test(self, filename: str, classname: str, test: str, endpoint: int = 0, pics: dict[str, bool] = {}):
50+
self.config = MatterTestConfig(
51+
tests=[test], endpoint=endpoint, dut_node_ids=[1], pics=pics)
4952
module = importlib.import_module(Path(os.path.basename(filename)).stem)
5053
self.test_class = getattr(module, classname)
5154

5255
def Shutdown(self):
5356
self.stack.Shutdown()
5457

55-
def run_test_with_mock_read(self, read_cache: Attribute.AsyncReadTransaction.ReadResponse):
58+
def run_test_with_mock_read(self, read_cache: Attribute.AsyncReadTransaction.ReadResponse, hooks = None):
5659
self.default_controller.Read = AsyncMock(return_value=read_cache)
57-
return run_tests_no_exit(self.test_class, self.config, None, self.default_controller, self.stack)
60+
return run_tests_no_exit(self.test_class, self.config, hooks, self.default_controller, self.stack)

0 commit comments

Comments
 (0)