Skip to content

Commit 5dc0419

Browse files
Adds get_service_by_record_type to mdns class
1 parent f5ebcf7 commit 5dc0419

File tree

2 files changed

+66
-76
lines changed

2 files changed

+66
-76
lines changed

src/python_testing/mdns_discovery/mdns_discovery.py

+66-50
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,13 @@
2020
import json
2121
from dataclasses import asdict, dataclass
2222
from enum import Enum
23+
from time import sleep
2324
from typing import Dict, List, Optional
2425

25-
from zeroconf import IPVersion, ServiceListener, ServiceStateChange, Zeroconf, DNSRecordType
26+
from zeroconf import IPVersion, ServiceListener, ServiceStateChange, Zeroconf
2627
from zeroconf.asyncio import AsyncServiceBrowser, AsyncServiceInfo, AsyncZeroconfServiceTypes
2728

29+
from mdns_discovery.mdns_async_service_info import DNSRecordType, MdnsAsyncServiceInfo
2830

2931
@dataclass
3032
class MdnsServiceInfo:
@@ -84,6 +86,7 @@ def __init__(self):
8486
self.updated_event = asyncio.Event()
8587

8688
def add_service(self, zeroconf: Zeroconf, service_type: str, name: str) -> None:
89+
sleep(0.5)
8790
self.updated_event.set()
8891

8992
def remove_service(self, zeroconf: Zeroconf, service_type: str, name: str) -> None:
@@ -151,66 +154,20 @@ async def get_commissionable_service(self, log_output: bool = False,
151154
"""
152155
return await self._get_service(MdnsServiceType.COMMISSIONABLE, log_output, discovery_timeout_sec)
153156

154-
async def get_operational_service(self, service_name: str = None,
155-
service_type: str = None,
156-
discovery_timeout_sec: float = DISCOVERY_TIMEOUT_SEC,
157-
log_output: bool = False
157+
async def get_operational_service(self, log_output: bool = False,
158+
discovery_timeout_sec: float = DISCOVERY_TIMEOUT_SEC
158159
) -> Optional[MdnsServiceInfo]:
159160
"""
160161
Asynchronously discovers an operational mDNS service within the network.
161162
162163
Args:
163164
log_output (bool): Logs the discovered services to the console. Defaults to False.
164165
discovery_timeout_sec (float): Defaults to 15 seconds.
165-
service_name (str): The unique name of the mDNS service. Defaults to None.
166-
service_type (str): The service type of the service. Defaults to None.
167166
168167
Returns:
169168
Optional[MdnsServiceInfo]: An instance of MdnsServiceInfo or None if timeout reached.
170169
"""
171-
# Validation to ensure both or none of the parameters are provided
172-
if (service_name is None) != (service_type is None):
173-
raise ValueError("Both service_name and service_type must be provided together or not at all.")
174-
175-
mdns_service_info = None
176-
177-
if service_name is None and service_type is None:
178-
mdns_service_info = await self._get_service(MdnsServiceType.OPERATIONAL, log_output, discovery_timeout_sec)
179-
else:
180-
print(f"Looking for MDNS service type '{service_type}', service name '{service_name}'")
181-
182-
# Adds service listener
183-
service_listener = MdnsServiceListener()
184-
self._zc.add_service_listener(MdnsServiceType.OPERATIONAL.value, service_listener)
185-
186-
# Wait for the add/update service event or timeout
187-
try:
188-
await asyncio.wait_for(service_listener.updated_event.wait(), discovery_timeout_sec)
189-
except asyncio.TimeoutError:
190-
print(f"Service lookup for {service_name} timeout ({discovery_timeout_sec}) reached without an update.")
191-
finally:
192-
self._zc.remove_service_listener(service_listener)
193-
194-
# Get service info
195-
service_info = AsyncServiceInfo(service_type, service_name)
196-
is_discovered = await service_info.async_request(
197-
self._zc,
198-
3000,
199-
record_type=DNSRecordType.A,
200-
load_from_cache=False)
201-
202-
# Adds service to discovered services
203-
if is_discovered:
204-
mdns_service_info = self._to_mdns_service_info_class(service_info)
205-
self._discovered_services = {}
206-
self._discovered_services[service_type] = []
207-
if mdns_service_info is not None:
208-
self._discovered_services[service_type].append(mdns_service_info)
209-
210-
if log_output:
211-
self._log_output()
212-
213-
return mdns_service_info
170+
return await self._get_service(MdnsServiceType.OPERATIONAL, log_output, discovery_timeout_sec)
214171

215172
async def get_border_router_service(self, log_output: bool = False,
216173
discovery_timeout_sec: float = DISCOVERY_TIMEOUT_SEC
@@ -268,6 +225,65 @@ async def get_service_types(self, log_output: bool = False) -> List[str]:
268225

269226
return discovered_services
270227

228+
async def get_service_by_record_type(self, service_name: str,
229+
service_type: str,
230+
record_type: DNSRecordType,
231+
load_from_cache: bool = True,
232+
discovery_timeout_sec: float = DISCOVERY_TIMEOUT_SEC,
233+
log_output: bool = False
234+
) -> Optional[MdnsServiceInfo]:
235+
"""
236+
Asynchronously discovers an mDNS service within the network by service name, service type,
237+
and record type.
238+
239+
Args:
240+
log_output (bool): Logs the discovered services to the console. Defaults to False.
241+
discovery_timeout_sec (float): Defaults to 15 seconds.
242+
service_name (str): The unique name of the mDNS service. Defaults to None.
243+
service_type (str): The service type of the service. Defaults to None.
244+
record_type (DNSRecordType): The type of record to look for (SRV, TXT, AAAA, A).
245+
246+
Returns:
247+
Optional[MdnsServiceInfo]: An instance of MdnsServiceInfo or None if timeout reached.
248+
"""
249+
mdns_service_info = None
250+
251+
print(
252+
f"Looking for MDNS service type '{service_type}', service name '{service_name}', record type '{record_type.name}'")
253+
254+
# Adds service listener
255+
service_listener = MdnsServiceListener()
256+
self._zc.add_service_listener(MdnsServiceType.OPERATIONAL.value, service_listener)
257+
258+
# Wait for the add/update service event or timeout
259+
try:
260+
await asyncio.wait_for(service_listener.updated_event.wait(), discovery_timeout_sec)
261+
except asyncio.TimeoutError:
262+
print(f"Service lookup for {service_name} timeout ({discovery_timeout_sec}) reached without an update.")
263+
finally:
264+
self._zc.remove_service_listener(service_listener)
265+
266+
# Get service info
267+
service_info = MdnsAsyncServiceInfo(service_type, service_name)
268+
is_discovered = await service_info.async_request(
269+
self._zc,
270+
3000,
271+
record_type=record_type,
272+
load_from_cache=load_from_cache)
273+
274+
# Adds service to discovered services
275+
if is_discovered:
276+
mdns_service_info = self._to_mdns_service_info_class(service_info)
277+
self._discovered_services = {}
278+
self._discovered_services[service_type] = []
279+
if mdns_service_info is not None:
280+
self._discovered_services[service_type].append(mdns_service_info)
281+
282+
if log_output:
283+
self._log_output()
284+
285+
return mdns_service_info
286+
271287
# Private methods
272288
async def _discover(self,
273289
discovery_timeout_sec: float,

src/python_testing/mdns_discovery/mdns_service_type_enum.py

-26
This file was deleted.

0 commit comments

Comments
 (0)