Skip to content

Commit 20e4359

Browse files
Adds MdnsDiscovery class (project-chip#31645)
* Adds MdnsDiscovery class * Fix restyle/lint * Fix restyle * Adds zeroconf dependency in tests.yaml * Relocates zeroconf dependency from tests.yaml to requirements.txt * Addresses latest review comments * Fixes typo * Updates instantiation method and initial discovery * Fix restyle/lint * Addresses latest review comments * Addresses latest review comments * Fix restlye/lint * Addresses review comments * restore enum * Refactor progress * Major refactor to discover() and get_operational_service_info, pending other get methods * Fix restyle * Fix restyle/lint * Updates descriptions and variables * Major refactor #2 * Fix restyle * Updates method descriptions * Fix restyle * Addresses review comments * Fix restyle * Fix lint * Default parameters adjustment --------- Co-authored-by: C Freeman <cecille@google.com>
1 parent 1259831 commit 20e4359

File tree

3 files changed

+402
-0
lines changed

3 files changed

+402
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,375 @@
1+
#
2+
# Copyright (c) 2024 Project CHIP Authors
3+
# All rights reserved.
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
#
17+
18+
19+
import asyncio
20+
import json
21+
from dataclasses import asdict, dataclass
22+
from enum import Enum
23+
from typing import Dict, List, Optional
24+
25+
from zeroconf import IPVersion, ServiceStateChange, Zeroconf
26+
from zeroconf.asyncio import AsyncServiceBrowser, AsyncServiceInfo, AsyncZeroconfServiceTypes
27+
28+
29+
@dataclass
30+
class MdnsServiceInfo:
31+
# The unique name of the mDNS service.
32+
service_name: str
33+
34+
# The service type of the service, typically indicating the service protocol and domain.
35+
service_type: str
36+
37+
# The instance name of the service.
38+
instance_name: str
39+
40+
# The domain name of the machine hosting the service.
41+
server: str
42+
43+
# The network port on which the service is available.
44+
port: int
45+
46+
# A list of IP addresses associated with the service.
47+
addresses: list[str]
48+
49+
# A dictionary of key-value pairs representing the service's metadata.
50+
txt_record: dict[str, str]
51+
52+
# The priority of the service, used in service selection when multiple instances are available.
53+
priority: int
54+
55+
# The network interface index on which the service is advertised.
56+
interface_index: int
57+
58+
# The relative weight for records with the same priority, used in load balancing.
59+
weight: int
60+
61+
# The time-to-live value for the host name in the DNS record.
62+
host_ttl: int
63+
64+
# The time-to-live value for other records associated with the service.
65+
other_ttl: int
66+
67+
68+
class MdnsServiceType(Enum):
69+
"""
70+
Enum for Matter mDNS service types used in network service discovery.
71+
"""
72+
COMMISSIONER = "_matterd._udp.local."
73+
COMMISSIONABLE = "_matterc._udp.local."
74+
OPERATIONAL = "_matter._tcp.local."
75+
BORDER_ROUTER = "_meshcop._udp.local."
76+
77+
78+
class MdnsDiscovery:
79+
80+
DISCOVERY_TIMEOUT_SEC = 15
81+
82+
def __init__(self):
83+
"""
84+
Initializes the MdnsDiscovery instance.
85+
86+
Main methods:
87+
- get_commissioner_service
88+
- get_commissionable_service
89+
- get_operational_service
90+
- get_border_router_service
91+
- get_all_services
92+
"""
93+
# An instance of Zeroconf to manage mDNS operations.
94+
self._zc = Zeroconf(ip_version=IPVersion.V6Only)
95+
96+
# A dictionary to store discovered services.
97+
self._discovered_services = {}
98+
99+
# A list of service types
100+
self._service_types = []
101+
102+
# An asyncio Event to signal when a service has been discovered
103+
self._event = asyncio.Event()
104+
105+
# Public methods
106+
async def get_commissioner_service(self, log_output: bool = False,
107+
discovery_timeout_sec: float = DISCOVERY_TIMEOUT_SEC
108+
) -> Optional[MdnsServiceInfo]:
109+
"""
110+
Asynchronously discovers a commissioner mDNS service within the network.
111+
112+
Args:
113+
log_output (bool): Logs the discovered services to the console. Defaults to False.
114+
discovery_timeout_sec (float): Defaults to 15 seconds.
115+
116+
Returns:
117+
Optional[MdnsServiceInfo]: An instance of MdnsServiceInfo or None if timeout reached.
118+
"""
119+
return await self._get_service(MdnsServiceType.COMMISSIONER, log_output, discovery_timeout_sec)
120+
121+
async def get_commissionable_service(self, log_output: bool = False,
122+
discovery_timeout_sec: float = DISCOVERY_TIMEOUT_SEC
123+
) -> Optional[MdnsServiceInfo]:
124+
"""
125+
Asynchronously discovers a commissionable mDNS service within the network.
126+
127+
Args:
128+
log_output (bool): Logs the discovered services to the console. Defaults to False.
129+
discovery_timeout_sec (float): Defaults to 15 seconds.
130+
131+
Returns:
132+
Optional[MdnsServiceInfo]: An instance of MdnsServiceInfo or None if timeout reached.
133+
"""
134+
return await self._get_service(MdnsServiceType.COMMISSIONABLE, log_output, discovery_timeout_sec)
135+
136+
async def get_operational_service(self, service_name: str = None,
137+
service_type: str = None,
138+
discovery_timeout_sec: float = DISCOVERY_TIMEOUT_SEC,
139+
log_output: bool = False
140+
) -> Optional[MdnsServiceInfo]:
141+
"""
142+
Asynchronously discovers an operational mDNS service within the network.
143+
144+
Args:
145+
log_output (bool): Logs the discovered services to the console. Defaults to False.
146+
discovery_timeout_sec (float): Defaults to 15 seconds.
147+
service_name (str): The unique name of the mDNS service. Defaults to None.
148+
service_type (str): The service type of the service. Defaults to None.
149+
150+
Returns:
151+
Optional[MdnsServiceInfo]: An instance of MdnsServiceInfo or None if timeout reached.
152+
"""
153+
# Validation to ensure both or none of the parameters are provided
154+
if (service_name is None) != (service_type is None):
155+
raise ValueError("Both service_name and service_type must be provided together or not at all.")
156+
157+
mdns_service_info = None
158+
159+
if service_name is None and service_type is None:
160+
mdns_service_info = await self._get_service(MdnsServiceType.OPERATIONAL, log_output, discovery_timeout_sec)
161+
else:
162+
print(f"Looking for MDNS service type '{service_type}', service name '{service_name}'")
163+
164+
# Get service info
165+
service_info = AsyncServiceInfo(service_type, service_name)
166+
is_discovered = await service_info.async_request(self._zc, 3000)
167+
if is_discovered:
168+
mdns_service_info = self._to_mdns_service_info_class(service_info)
169+
self._discovered_services = {}
170+
self._discovered_services[service_type] = [mdns_service_info]
171+
172+
if log_output:
173+
self._log_output()
174+
175+
return mdns_service_info
176+
177+
async def get_border_router_service(self, log_output: bool = False,
178+
discovery_timeout_sec: float = DISCOVERY_TIMEOUT_SEC
179+
) -> Optional[MdnsServiceInfo]:
180+
"""
181+
Asynchronously discovers a border router mDNS service within the network.
182+
183+
Args:
184+
log_output (bool): Logs the discovered services to the console. Defaults to False.
185+
discovery_timeout_sec (float): Defaults to 15 seconds.
186+
187+
Returns:
188+
Optional[MdnsServiceInfo]: An instance of MdnsServiceInfo or None if timeout reached.
189+
"""
190+
return await self._get_service(MdnsServiceType.BORDER_ROUTER, log_output, discovery_timeout_sec)
191+
192+
async def get_all_services(self, log_output: bool = False,
193+
discovery_timeout_sec: float = DISCOVERY_TIMEOUT_SEC
194+
) -> Dict[str, List[MdnsServiceInfo]]:
195+
"""
196+
Asynchronously discovers all available mDNS services within the network.
197+
198+
Args:
199+
log_output (bool): Logs the discovered services to the console. Defaults to False.
200+
discovery_timeout_sec (float): Defaults to 15 seconds.
201+
202+
Returns:
203+
Dict[str, List[MdnsServiceInfo]]: A dictionary mapping service types (str) to
204+
lists of MdnsServiceInfo objects.
205+
"""
206+
await self._discover(discovery_timeout_sec, log_output, all_services=True)
207+
208+
return self._discovered_services
209+
210+
# Private methods
211+
async def _discover(self,
212+
discovery_timeout_sec: float,
213+
log_output: bool,
214+
all_services: bool = False
215+
) -> None:
216+
"""
217+
Asynchronously discovers network services using multicast DNS (mDNS).
218+
219+
Args:
220+
discovery_timeout_sec (float): The duration in seconds to wait for the discovery process, allowing for service
221+
announcements to be collected.
222+
all_services (bool): If True, discovers all available mDNS services. If False, discovers services based on the
223+
predefined `_service_types` list. Defaults to False.
224+
log_output (bool): If True, logs the discovered services to the console in JSON format for debugging or informational
225+
purposes. Defaults to False.
226+
227+
Returns:
228+
None: This method does not return any value.
229+
230+
Note:
231+
The discovery duration may need to be adjusted based on network conditions and expected response times for service
232+
announcements. The method leverages an asyncio event to manage asynchronous waiting and cancellation based on discovery
233+
success or timeout.
234+
"""
235+
self._event.clear()
236+
237+
if all_services:
238+
self._service_types = list(await AsyncZeroconfServiceTypes.async_find())
239+
240+
print(f"Browsing for MDNS service(s) of type: {self._service_types}")
241+
242+
aiobrowser = AsyncServiceBrowser(zeroconf=self._zc,
243+
type_=self._service_types,
244+
handlers=[self._on_service_state_change]
245+
)
246+
247+
try:
248+
await asyncio.wait_for(self._event.wait(), timeout=discovery_timeout_sec)
249+
except asyncio.TimeoutError:
250+
print(f"MDNS service discovery timed out after {discovery_timeout_sec} seconds.")
251+
finally:
252+
await aiobrowser.async_cancel()
253+
254+
if log_output:
255+
self._log_output()
256+
257+
def _on_service_state_change(
258+
self,
259+
zeroconf: Zeroconf,
260+
service_type: str,
261+
name: str,
262+
state_change: ServiceStateChange
263+
) -> None:
264+
"""
265+
Callback method triggered on mDNS service state change.
266+
267+
This method is called by the Zeroconf library when there is a change in the state of an mDNS service.
268+
It handles the addition of new services by initiating a query for their detailed information.
269+
270+
Args:
271+
zeroconf (Zeroconf): The Zeroconf instance managing the network operations.
272+
service_type (str): The service type of the mDNS service that changed state.
273+
name (str): The service name of the mDNS service.
274+
state_change (ServiceStateChange): The type of state change that occurred.
275+
276+
Returns:
277+
None: This method does not return any value.
278+
"""
279+
if state_change.value == ServiceStateChange.Added.value:
280+
self._event.set()
281+
asyncio.ensure_future(self._query_service_info(
282+
zeroconf,
283+
service_type,
284+
name)
285+
)
286+
287+
async def _query_service_info(self, zeroconf: Zeroconf, service_type: str, service_name: str) -> None:
288+
"""
289+
This method queries for service details such as its address, port, and TXT records
290+
containing metadata.
291+
292+
Args:
293+
zeroconf (Zeroconf): The Zeroconf instance used for managing network operations and service discovery.
294+
service_type (str): The type of the mDNS service being queried.
295+
service_name (str): The specific service name of the mDNS service to query. This service name uniquely
296+
identifies the service instance within the local network.
297+
298+
Returns:
299+
None: This method does not return any value.
300+
"""
301+
# Get service info
302+
service_info = AsyncServiceInfo(service_type, service_name)
303+
is_service_discovered = await service_info.async_request(zeroconf, 3000)
304+
service_info.async_clear_cache()
305+
306+
if is_service_discovered:
307+
mdns_service_info = self._to_mdns_service_info_class(service_info)
308+
309+
if service_type not in self._discovered_services:
310+
self._discovered_services[service_type] = [mdns_service_info]
311+
else:
312+
self._discovered_services[service_type].append(mdns_service_info)
313+
314+
def _to_mdns_service_info_class(self, service_info: AsyncServiceInfo) -> MdnsServiceInfo:
315+
"""
316+
Converts an AsyncServiceInfo object into a MdnsServiceInfo data class.
317+
318+
Args:
319+
service_info (AsyncServiceInfo): The service information to convert.
320+
321+
Returns:
322+
MdnsServiceInfo: The converted service information as a data class.
323+
"""
324+
mdns_service_info = MdnsServiceInfo(
325+
service_name=service_info.name,
326+
service_type=service_info.type,
327+
instance_name=service_info.get_name(),
328+
server=service_info.server,
329+
port=service_info.port,
330+
addresses=service_info.parsed_addresses(),
331+
txt_record=service_info.decoded_properties,
332+
priority=service_info.priority,
333+
interface_index=service_info.interface_index,
334+
weight=service_info.weight,
335+
host_ttl=service_info.host_ttl,
336+
other_ttl=service_info.other_ttl
337+
)
338+
339+
return mdns_service_info
340+
341+
async def _get_service(self, service_type: MdnsServiceType,
342+
log_output: bool,
343+
discovery_timeout_sec: float
344+
) -> Optional[MdnsServiceInfo]:
345+
"""
346+
Asynchronously discovers a specific type of mDNS service within the network and returns its details.
347+
348+
Args:
349+
service_type (MdnsServiceType): The enum representing the type of mDNS service to discover.
350+
log_output (bool): Logs the discovered services to the console. Defaults to False.
351+
discovery_timeout_sec (float): Defaults to 15 seconds.
352+
353+
Returns:
354+
Optional[MdnsServiceInfo]: An instance of MdnsServiceInfo representing the discovered service, if
355+
any. Returns None if no service of the specified type is discovered within
356+
the timeout period.
357+
"""
358+
mdns_service_info = None
359+
self._service_types = [service_type.value]
360+
await self._discover(discovery_timeout_sec, log_output)
361+
if service_type.value in self._discovered_services:
362+
mdns_service_info = self._discovered_services[service_type.value][0]
363+
364+
return mdns_service_info
365+
366+
def _log_output(self) -> str:
367+
"""
368+
Converts the discovered services to a JSON string and prints it.
369+
370+
The method is intended to be used for debugging or informational purposes, providing a clear and
371+
comprehensive view of all services discovered during the mDNS service discovery process.
372+
"""
373+
converted_services = {key: [asdict(item) for item in value] for key, value in self._discovered_services.items()}
374+
json_str = json.dumps(converted_services, indent=4)
375+
print(json_str)

0 commit comments

Comments
 (0)