Skip to content

Commit 35239e5

Browse files
authored
Fabric management (#249)
1 parent 945daf8 commit 35239e5

File tree

5 files changed

+168
-2
lines changed

5 files changed

+168
-2
lines changed

matter_server/client/client.py

+42-1
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import uuid
99

1010
from aiohttp import ClientSession
11+
from chip.clusters import Objects as Clusters
1112

1213
from matter_server.common.errors import ERROR_MAP
1314

@@ -27,7 +28,7 @@
2728
)
2829
from .connection import MatterClientConnection
2930
from .exceptions import ConnectionClosed, InvalidServerVersion, InvalidState
30-
from .models.node import MatterNode
31+
from .models.node import MatterFabricData, MatterNode
3132

3233
if TYPE_CHECKING:
3334
from chip.clusters.Objects import ClusterCommand
@@ -156,6 +157,46 @@ async def open_commissioning_window(
156157
),
157158
)
158159

160+
async def get_matter_fabrics(self, node_id: int) -> list[MatterFabricData]:
161+
"""
162+
Get Matter fabrics from a device.
163+
164+
Returns a list of MatterFabricData objects.
165+
"""
166+
167+
node = await self.get_node(node_id)
168+
fabrics: list[
169+
Clusters.OperationalCredentials.Structs.FabricDescriptor
170+
] = node.get_attribute_value(
171+
0, None, Clusters.OperationalCredentials.Attributes.Fabrics
172+
)
173+
174+
vendors_map = await self.send_command(
175+
APICommand.GET_VENDOR_NAMES,
176+
filter_vendors=[f.vendorId for f in fabrics],
177+
)
178+
179+
return [
180+
MatterFabricData(
181+
fabric_id=f.fabricId,
182+
vendor_id=f.vendorId,
183+
fabric_index=f.fabricIndex,
184+
fabric_label=f.label if f.label else None,
185+
vendor_name=vendors_map.get(f.vendorId),
186+
)
187+
for f in fabrics
188+
]
189+
190+
async def remove_matter_fabric(self, node_id: int, fabric_index: int) -> None:
191+
"""Remove Matter fabric from a device."""
192+
await self.send_device_command(
193+
node_id,
194+
0,
195+
Clusters.OperationalCredentials.Commands.RemoveFabric(
196+
fabricIndex=fabric_index,
197+
),
198+
)
199+
159200
async def send_device_command(
160201
self,
161202
node_id: int,

matter_server/client/models/node.py

+12
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Matter node."""
22
from __future__ import annotations
33

4+
from dataclasses import dataclass
45
import logging
56
from typing import Any, TypeVar, cast
67

@@ -40,6 +41,17 @@ def get_object_params(
4041
raise KeyError(f"No descriptor found for object {object_id}")
4142

4243

44+
@dataclass
45+
class MatterFabricData:
46+
"""Data about a Matter fabric."""
47+
48+
fabric_id: int
49+
vendor_id: int
50+
fabric_index: int
51+
fabric_label: str | None = None
52+
vendor_name: str | None = None
53+
54+
4355
class MatterEndpoint:
4456
"""Representation of a Matter Endpoint."""
4557

matter_server/common/models.py

+13
Original file line numberDiff line numberDiff line change
@@ -37,13 +37,26 @@ class APICommand(str, Enum):
3737
INTERVIEW_NODE = "interview_node"
3838
DEVICE_COMMAND = "device_command"
3939
REMOVE_NODE = "remove_node"
40+
GET_VENDOR_NAMES = "get_vendor_names"
4041

4142

4243
EventCallBackType = Callable[[EventType, Any], None]
4344

4445
# Generic model(s)
4546

4647

48+
@dataclass
49+
class VendorInfo:
50+
"""Vendor info as received from the CSA."""
51+
52+
vendor_id: int
53+
vendor_name: str
54+
company_legal_name: str
55+
company_preferred_name: str
56+
vendor_landing_page_url: str
57+
creator: str
58+
59+
4760
@dataclass
4861
class MatterNodeData:
4962
"""Matter node data as received from (and stored on) the server."""

matter_server/server/server.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from .device_controller import MatterDeviceController
2626
from .stack import MatterStack
2727
from .storage import StorageController
28+
from .vendor_info import VendorInfo
2829

2930

3031
def mount_websocket(server: MatterServer, path: str) -> None:
@@ -75,6 +76,7 @@ def __init__(
7576
# of Matter devices and their subscriptions.
7677
self.device_controller = MatterDeviceController(self)
7778
self.storage = StorageController(self)
79+
self.vendor_info = VendorInfo(self)
7880
# we dynamically register command handlers
7981
self.command_handlers: dict[str, APICommandHandler] = {}
8082
self._subscribers: Set[EventCallBackType] = set()
@@ -92,6 +94,7 @@ async def start(self) -> None:
9294
await self.device_controller.initialize()
9395
await self.storage.start()
9496
await self.device_controller.start()
97+
await self.vendor_info.start()
9598
mount_websocket(self, "/ws")
9699
self.app.router.add_route("GET", "/", self._handle_info)
97100
self._runner = web.AppRunner(self.app, access_log=None)
@@ -174,7 +177,7 @@ def register_api_command(
174177

175178
def _register_api_commands(self) -> None:
176179
"""Register all methods decorated as api_command."""
177-
for cls in (self, self.device_controller):
180+
for cls in (self, self.device_controller, self.vendor_info):
178181
for attr_name in dir(cls):
179182
if attr_name.startswith("__"):
180183
continue

matter_server/server/vendor_info.py

+97
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
"""Fetches vendor info from the CSA."""
2+
from __future__ import annotations
3+
4+
import logging
5+
from typing import TYPE_CHECKING
6+
7+
from aiohttp import ClientError, ClientSession
8+
9+
from ..common.helpers.api import api_command
10+
from ..common.helpers.util import dataclass_from_dict, dataclass_to_dict
11+
from ..common.models import APICommand, VendorInfo as VendorInfoModel
12+
13+
if TYPE_CHECKING:
14+
from .server import MatterServer
15+
16+
LOGGER = logging.getLogger(__name__)
17+
PRODUCTION_URL = "https://on.dcl.csa-iot.org"
18+
DATA_KEY_VENDOR_INFO = "vendor_info"
19+
20+
21+
class VendorInfo:
22+
"""Fetches vendor info from the CSA and handles api calls to get it."""
23+
24+
def __init__(self, server: MatterServer):
25+
"""Initialize the vendor info."""
26+
self._data: dict[int, VendorInfoModel] = {}
27+
self._server = server
28+
29+
async def start(self) -> None:
30+
"""Async initialize the vendor info."""
31+
self._load_vendors()
32+
await self._fetch_vendors()
33+
self._save_vendors()
34+
35+
def _load_vendors(self) -> None:
36+
"""Load vendor info from storage."""
37+
LOGGER.info("Loading vendor info from storage.")
38+
vendor_count = 0
39+
data = self._server.storage.get(DATA_KEY_VENDOR_INFO, {})
40+
for vendor_id, vendor_info in data.items():
41+
self._data[vendor_id] = dataclass_from_dict(VendorInfoModel, vendor_info)
42+
vendor_count += 1
43+
LOGGER.info("Loaded %s vendors from storage.", vendor_count)
44+
45+
async def _fetch_vendors(self) -> None:
46+
"""Fetch the vendor names from the CSA."""
47+
LOGGER.info("Fetching the latest vendor info from DCL.")
48+
vendors: dict[int, VendorInfoModel] = {}
49+
try:
50+
async with ClientSession(raise_for_status=True) as session:
51+
async with session.get(
52+
f"{PRODUCTION_URL}/dcl/vendorinfo/vendors"
53+
) as response:
54+
data = await response.json()
55+
for vendorinfo in data["vendorInfo"]:
56+
vendors[vendorinfo["vendorID"]] = VendorInfoModel(
57+
vendor_id=vendorinfo["vendorID"],
58+
vendor_name=vendorinfo["vendorName"],
59+
company_legal_name=vendorinfo["companyLegalName"],
60+
company_preferred_name=vendorinfo["companyPreferredName"],
61+
vendor_landing_page_url=vendorinfo["vendorLandingPageURL"],
62+
creator=vendorinfo["creator"],
63+
)
64+
except ClientError as err:
65+
LOGGER.error("Unable to fetch vendor info from DCL: %s", err)
66+
else:
67+
LOGGER.info("Fetched %s vendors from DCL.", len(vendors))
68+
69+
self._data.update(vendors)
70+
71+
def _save_vendors(self) -> None:
72+
"""Save vendor info to storage."""
73+
LOGGER.info("Saving vendor info to storage.")
74+
self._server.storage.set(
75+
DATA_KEY_VENDOR_INFO,
76+
{
77+
vendor_id: dataclass_to_dict(vendor_info)
78+
for vendor_id, vendor_info in self._data.items()
79+
},
80+
)
81+
82+
@api_command(APICommand.GET_VENDOR_NAMES)
83+
async def get_vendor_names(
84+
self, filter_vendors: list[int] | None = None
85+
) -> dict[int, str]:
86+
"""Get a map of vendor ids to vendor names."""
87+
if filter_vendors:
88+
vendors: dict[int, str] = {}
89+
for vendor_id in filter_vendors:
90+
if vendor_id in filter_vendors:
91+
vendors[vendor_id] = self._data[vendor_id].vendor_name
92+
return vendors
93+
94+
return {
95+
vendor_id: vendor_info.vendor_name
96+
for vendor_id, vendor_info in self._data.items()
97+
}

0 commit comments

Comments
 (0)