Skip to content

Commit f5ebcf7

Browse files
Adds MdnsAsyncServiceInfo class
1 parent cbb4006 commit f5ebcf7

File tree

1 file changed

+304
-0
lines changed

1 file changed

+304
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,304 @@
1+
#
2+
# Copyright (c) 2024 Project CHIP Authors
3+
# All rights reserved.
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
#
17+
18+
19+
import asyncio
20+
import enum
21+
from ipaddress import IPv4Address, IPv6Address
22+
from random import randint
23+
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Union
24+
from zeroconf import BadTypeInNameException, DNSQuestionType, ServiceInfo, Zeroconf, current_time_millis
25+
from zeroconf.const import (
26+
_DNS_HOST_TTL,
27+
_DNS_OTHER_TTL,
28+
_LISTENER_TIME,
29+
_MDNS_PORT,
30+
_DUPLICATE_QUESTION_INTERVAL,
31+
_FLAGS_QR_QUERY,
32+
_CLASS_IN,
33+
_TYPE_A,
34+
_TYPE_AAAA,
35+
_TYPE_SRV,
36+
_TYPE_TXT
37+
)
38+
from zeroconf._dns import (
39+
DNSQuestion,
40+
DNSAddress,
41+
DNSPointer,
42+
DNSQuestionType,
43+
DNSRecord,
44+
DNSService,
45+
DNSText
46+
)
47+
from zeroconf._protocol.outgoing import DNSOutgoing
48+
from zeroconf._utils.name import service_type_name
49+
from zeroconf._utils.net import _encode_address
50+
from zeroconf._cache import DNSCache
51+
from zeroconf._history import QuestionHistory
52+
53+
54+
int_ = int
55+
float_ = float
56+
str_ = str
57+
58+
59+
QU_QUESTION = DNSQuestionType.QU
60+
QM_QUESTION = DNSQuestionType.QM
61+
_AVOID_SYNC_DELAY_RANDOM_INTERVAL = (20, 120)
62+
63+
@enum.unique
64+
class DNSRecordType(enum.Enum):
65+
"""An MDNS record type.
66+
67+
"A" - A MDNS record type
68+
"AAAA" - AAAA MDNS record type
69+
"SRV" - SRV MDNS record type
70+
"TXT" - TXT MDNS record type
71+
"""
72+
73+
A = 0
74+
AAAA = 1
75+
SRV = 2
76+
TXT = 3
77+
78+
79+
class MdnsAsyncServiceInfo(ServiceInfo):
80+
def __init__(
81+
self,
82+
type_: str,
83+
name: str,
84+
port: Optional[int] = None,
85+
weight: int = 0,
86+
priority: int = 0,
87+
properties: Union[bytes, Dict] = b'',
88+
server: Optional[str] = None,
89+
host_ttl: int = _DNS_HOST_TTL,
90+
other_ttl: int = _DNS_OTHER_TTL,
91+
*,
92+
addresses: Optional[List[bytes]] = None,
93+
parsed_addresses: Optional[List[str]] = None,
94+
interface_index: Optional[int] = None,
95+
) -> None:
96+
# Accept both none, or one, but not both.
97+
if addresses is not None and parsed_addresses is not None:
98+
raise TypeError("addresses and parsed_addresses cannot be provided together")
99+
if not type_.endswith(service_type_name(name, strict=False)):
100+
raise BadTypeInNameException
101+
self.interface_index = interface_index
102+
self.text = b''
103+
self.type = type_
104+
self._name = name
105+
self.key = name.lower()
106+
self._ipv4_addresses: List[IPv4Address] = []
107+
self._ipv6_addresses: List[IPv6Address] = []
108+
if addresses is not None:
109+
self.addresses = addresses
110+
elif parsed_addresses is not None:
111+
self.addresses = [_encode_address(a) for a in parsed_addresses]
112+
self.port = port
113+
self.weight = weight
114+
self.priority = priority
115+
self.server = server if server else None
116+
self.server_key = server.lower() if server else None
117+
self._properties: Optional[Dict[bytes, Optional[bytes]]] = None
118+
self._decoded_properties: Optional[Dict[str, Optional[str]]] = None
119+
if isinstance(properties, bytes):
120+
self._set_text(properties)
121+
else:
122+
self._set_properties(properties)
123+
self.host_ttl = host_ttl
124+
self.other_ttl = other_ttl
125+
self._new_records_futures: Optional[Set[asyncio.Future]] = None
126+
self._dns_address_cache: Optional[List[DNSAddress]] = None
127+
self._dns_pointer_cache: Optional[DNSPointer] = None
128+
self._dns_service_cache: Optional[DNSService] = None
129+
self._dns_text_cache: Optional[DNSText] = None
130+
self._get_address_and_nsec_records_cache: Optional[Set[DNSRecord]] = None
131+
132+
async def async_request(
133+
self,
134+
zc: 'Zeroconf',
135+
timeout: float,
136+
question_type: Optional[DNSQuestionType] = None,
137+
addr: Optional[str] = None,
138+
port: int = _MDNS_PORT,
139+
record_type: DNSRecordType = None,
140+
load_from_cache: bool = True
141+
) -> bool:
142+
"""Returns true if the service could be discovered on the
143+
network, and updates this object with details discovered.
144+
145+
This method will be run in the event loop.
146+
147+
Passing addr and port is optional, and will default to the
148+
mDNS multicast address and port. This is useful for directing
149+
requests to a specific host that may be able to respond across
150+
subnets.
151+
"""
152+
if not zc.started:
153+
await zc.async_wait_for_start()
154+
155+
now = current_time_millis()
156+
157+
if load_from_cache:
158+
if self._load_from_cache(zc, now):
159+
return True
160+
161+
if TYPE_CHECKING:
162+
assert zc.loop is not None
163+
164+
first_request = True
165+
delay = self._get_initial_delay()
166+
next_ = now
167+
last = now + timeout
168+
try:
169+
zc.async_add_listener(self, None)
170+
while not self._is_complete:
171+
if last <= now:
172+
return False
173+
if next_ <= now:
174+
this_question_type = question_type or QU_QUESTION if first_request else QM_QUESTION
175+
out = self._generate_request_query(zc, now, this_question_type, record_type)
176+
first_request = False
177+
if out.questions:
178+
# All questions may have been suppressed
179+
# by the question history, so nothing to send,
180+
# but keep waiting for answers in case another
181+
# client on the network is asking the same
182+
# question or they have not arrived yet.
183+
zc.async_send(out, addr, port)
184+
next_ = now + delay
185+
next_ += self._get_random_delay()
186+
if this_question_type is QM_QUESTION and delay < _DUPLICATE_QUESTION_INTERVAL:
187+
# If we just asked a QM question, we need to
188+
# wait at least the duplicate question interval
189+
# before asking another QM question otherwise
190+
# its likely to be suppressed by the question
191+
# history of the remote responder.
192+
delay = _DUPLICATE_QUESTION_INTERVAL
193+
194+
await self.async_wait(min(next_, last) - now, zc.loop)
195+
now = current_time_millis()
196+
finally:
197+
zc.async_remove_listener(self)
198+
199+
return True
200+
201+
def _generate_request_query(
202+
self, zc: 'Zeroconf', now: float_, question_type: DNSQuestionType, record_type: DNSRecordType
203+
) -> DNSOutgoing:
204+
"""Generate the request query."""
205+
out = DNSOutgoing(_FLAGS_QR_QUERY)
206+
name = self._name
207+
server = self.server or name
208+
cache = zc.cache
209+
history = zc.question_history
210+
qu_question = question_type is QU_QUESTION
211+
if record_type is None or record_type is DNSRecordType.SRV:
212+
print("Requesting MDNS SRV record...")
213+
self._add_question_with_known_answers(
214+
out, qu_question, history, cache, now, name, _TYPE_SRV, _CLASS_IN, True
215+
)
216+
if record_type is None or record_type is DNSRecordType.TXT:
217+
print("Requesting MDNS TXT record...")
218+
self._add_question_with_known_answers(
219+
out, qu_question, history, cache, now, name, _TYPE_TXT, _CLASS_IN, True
220+
)
221+
if record_type is None or record_type is DNSRecordType.A:
222+
print("Requesting MDNS A record...")
223+
self._add_question_with_known_answers(
224+
out, qu_question, history, cache, now, server, _TYPE_A, _CLASS_IN, False
225+
)
226+
if record_type is None or record_type is DNSRecordType.AAAA:
227+
print("Requesting MDNS AAAA record...")
228+
self._add_question_with_known_answers(
229+
out, qu_question, history, cache, now, server, _TYPE_AAAA, _CLASS_IN, False
230+
)
231+
return out
232+
233+
def _add_question_with_known_answers(
234+
self,
235+
out: DNSOutgoing,
236+
qu_question: bool,
237+
question_history: QuestionHistory,
238+
cache: DNSCache,
239+
now: float_,
240+
name: str_,
241+
type_: int_,
242+
class_: int_,
243+
skip_if_known_answers: bool,
244+
) -> None:
245+
"""Add a question with known answers if its not suppressed."""
246+
known_answers = {
247+
answer for answer in cache.get_all_by_details(name, type_, class_) if not answer.is_stale(now)
248+
}
249+
if skip_if_known_answers and known_answers:
250+
return
251+
question = DNSQuestion(name, type_, class_)
252+
if qu_question:
253+
question.unicast = True
254+
elif question_history.suppresses(question, now, known_answers):
255+
return
256+
else:
257+
question_history.add_question_at_time(question, now, known_answers)
258+
out.add_question(question)
259+
for answer in known_answers:
260+
out.add_answer_at_time(answer, now)
261+
262+
def _get_initial_delay(self) -> float_:
263+
return _LISTENER_TIME
264+
265+
def _get_random_delay(self) -> int_:
266+
return randint(*_AVOID_SYNC_DELAY_RANDOM_INTERVAL)
267+
268+
def _set_properties(self, properties: Dict[Union[str, bytes], Optional[Union[str, bytes]]]) -> None:
269+
"""Sets properties and text of this info from a dictionary"""
270+
list_: List[bytes] = []
271+
properties_contain_str = False
272+
result = b''
273+
for key, value in properties.items():
274+
if isinstance(key, str):
275+
key = key.encode('utf-8')
276+
properties_contain_str = True
277+
278+
record = key
279+
if value is not None:
280+
if not isinstance(value, bytes):
281+
value = str(value).encode('utf-8')
282+
properties_contain_str = True
283+
record += b'=' + value
284+
list_.append(record)
285+
for item in list_:
286+
result = b''.join((result, bytes((len(item),)), item))
287+
if not properties_contain_str:
288+
# If there are no str keys or values, we can use the properties
289+
# as-is, without decoding them, otherwise calling
290+
# self.properties will lazy decode them, which is expensive.
291+
if TYPE_CHECKING:
292+
self._properties = cast("Dict[bytes, Optional[bytes]]", properties)
293+
else:
294+
self._properties = properties
295+
self.text = result
296+
297+
def _set_text(self, text: bytes) -> None:
298+
"""Sets properties and text given a text field"""
299+
if text == self.text:
300+
return
301+
self.text = text
302+
# Clear the properties cache
303+
self._properties = None
304+
self._decoded_properties = None

0 commit comments

Comments
 (0)