Skip to content

Commit 6c78bff

Browse files
committed
Fix the dnssd code that browses on both the local and srp domains
- Fixes the UAF issue with the timer - Resolves critical comments from PR project-chip#32631
1 parent 1f2aff4 commit 6c78bff

File tree

3 files changed

+127
-90
lines changed

3 files changed

+127
-90
lines changed

src/platform/Darwin/DnssdContexts.cpp

+47-33
Original file line numberDiff line numberDiff line change
@@ -464,6 +464,7 @@ ResolveContext::ResolveContext(void * cbContext, DnssdResolveCallback cb, chip::
464464
protocol = GetProtocol(cbAddressType);
465465
instanceName = instanceNameToResolve;
466466
consumerCounter = std::move(consumerCounterToUse);
467+
hasSrpTimerStarted = false;
467468
}
468469

469470
ResolveContext::ResolveContext(CommissioningResolveDelegate * delegate, chip::Inet::IPAddressType cbAddressType,
@@ -476,9 +477,15 @@ ResolveContext::ResolveContext(CommissioningResolveDelegate * delegate, chip::In
476477
protocol = GetProtocol(cbAddressType);
477478
instanceName = instanceNameToResolve;
478479
consumerCounter = std::move(consumerCounterToUse);
480+
hasSrpTimerStarted = false;
479481
}
480482

481-
ResolveContext::~ResolveContext() {}
483+
ResolveContext::~ResolveContext() {
484+
if (this->hasSrpTimerStarted)
485+
{
486+
CancelSrpTimer(this);
487+
}
488+
}
482489

483490
void ResolveContext::DispatchFailure(const char * errorStr, CHIP_ERROR err)
484491
{
@@ -526,8 +533,7 @@ void ResolveContext::DispatchSuccess()
526533

527534
for (auto interfaceIndex : priorityInterfaceIndices)
528535
{
529-
// Try finding interfaces for domains kLocalDot and kOpenThreadDot and delete them.
530-
if (TryReportingResultsForInterfaceIndex(static_cast<uint32_t>(interfaceIndex), std::string(kLocalDot)))
536+
if (TryReportingResultsForInterfaceIndex(interfaceIndex))
531537
{
532538
if (needDelete)
533539
{
@@ -536,7 +542,7 @@ void ResolveContext::DispatchSuccess()
536542
return;
537543
}
538544

539-
if (TryReportingResultsForInterfaceIndex(static_cast<uint32_t>(interfaceIndex), std::string(kOpenThreadDot)))
545+
if (TryReportingResultsForInterfaceIndex(interfaceIndex))
540546
{
541547
if (needDelete)
542548
{
@@ -548,7 +554,8 @@ void ResolveContext::DispatchSuccess()
548554

549555
for (auto & interface : interfaces)
550556
{
551-
if (TryReportingResultsForInterfaceIndex(interface.first.first, interface.first.second))
557+
auto interfaceId = interface.first.first;
558+
if (TryReportingResultsForInterfaceIndex(interfaceId))
552559
{
553560
break;
554561
}
@@ -560,52 +567,60 @@ void ResolveContext::DispatchSuccess()
560567
}
561568
}
562569

563-
bool ResolveContext::TryReportingResultsForInterfaceIndex(uint32_t interfaceIndex, std::string domainName)
570+
bool ResolveContext::TryReportingResultsForInterfaceIndex(uint32_t interfaceIndex)
564571
{
565572
if (interfaceIndex == 0)
566573
{
567574
// Not actually an interface we have.
568575
return false;
569576
}
570577

571-
std::pair<uint32_t, std::string> interfaceKey = std::make_pair(interfaceIndex, domainName);
572-
auto & interface = interfaces[interfaceKey];
573-
auto & ips = interface.addresses;
578+
std::map<std::pair<uint32_t, std::string>, InterfaceInfo>::iterator iter = interfaces.begin();
579+
while (iter != interfaces.end())
580+
{
581+
std::pair<uint32_t, std::string> key = iter->first;
582+
if (key.first == interfaceIndex)
583+
{
584+
auto & interface = interfaces[key];
585+
auto & ips = interface.addresses;
574586

575-
// Some interface may not have any ips, just ignore them.
576-
if (ips.size() == 0)
577-
{
578-
return false;
579-
}
587+
// Some interface may not have any ips, just ignore them.
588+
if (ips.size() == 0)
589+
{
590+
return false;
591+
}
580592

581-
ChipLogProgress(Discovery, "Mdns: Resolve success on interface %" PRIu32, interfaceIndex);
593+
ChipLogProgress(Discovery, "Mdns: Resolve success on interface %" PRIu32, interfaceIndex);
582594

583-
auto & service = interface.service;
584-
auto addresses = Span<Inet::IPAddress>(ips.data(), ips.size());
585-
if (nullptr == callback)
586-
{
587-
auto delegate = static_cast<CommissioningResolveDelegate *>(context);
588-
DiscoveredNodeData nodeData;
589-
service.ToDiscoveredNodeData(addresses, nodeData);
590-
delegate->OnNodeDiscovered(nodeData);
591-
}
592-
else
593-
{
594-
callback(context, &service, addresses, CHIP_NO_ERROR);
595-
}
595+
auto & service = interface.service;
596+
auto addresses = Span<Inet::IPAddress>(ips.data(), ips.size());
597+
if (nullptr == callback)
598+
{
599+
auto delegate = static_cast<CommissioningResolveDelegate *>(context);
600+
DiscoveredNodeData nodeData;
601+
service.ToDiscoveredNodeData(addresses, nodeData);
602+
delegate->OnNodeDiscovered(nodeData);
603+
}
604+
else
605+
{
606+
callback(context, &service, addresses, CHIP_NO_ERROR);
607+
}
596608

597-
return true;
609+
return true;
610+
}
611+
}
612+
return false;
598613
}
599614

600-
CHIP_ERROR ResolveContext::OnNewAddress(const std::pair<uint32_t, std::string> interfaceKey, const struct sockaddr * address)
615+
CHIP_ERROR ResolveContext::OnNewAddress(const std::pair<uint32_t, std::string> & interfaceKey, const struct sockaddr * address)
601616
{
602617
// If we don't have any information about this interfaceId, just ignore the
603618
// address, since it won't be usable anyway without things like the port.
604619
// This can happen if "local" is set up as a search domain in the DNS setup
605620
// on the system, because the hostnames we are looking up all end in
606621
// ".local". In other words, we can get regular DNS results in here, not
607622
// just DNS-SD ones.
608-
uint32_t interfaceId = interfaceKey.first;
623+
auto interfaceId = interfaceKey.first;
609624

610625
if (interfaces.find(interfaceKey) == interfaces.end())
611626
{
@@ -720,8 +735,7 @@ void ResolveContext::OnNewInterface(uint32_t interfaceId, const char * fullname,
720735
}
721736

722737
std::pair<uint32_t, std::string> interfaceKey = std::make_pair(interfaceId, domainFromHostname);
723-
724-
interfaces.insert(std::make_pair(interfaceKey, std::move(interface)));
738+
interfaces.insert(std::make_pair(std::move(interfaceKey), std::move(interface)));
725739
}
726740

727741
bool ResolveContext::HasInterface()

src/platform/Darwin/DnssdImpl.cpp

+71-46
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,12 @@ using namespace chip::Dnssd::Internal;
3232

3333
namespace {
3434

35-
// The extra time in milliseconds that we will wait for the resolution on the open thread domain to complete.
36-
constexpr uint16_t kOpenThreadTimeoutInMsec = 250;
35+
constexpr char kLocalDot[] = "local.";
36+
37+
constexpr char kSrpDot[] = "default.service.arpa.";
38+
39+
// The extra time in milliseconds that we will wait for the resolution on the srp domain to complete.
40+
constexpr uint16_t kSrpTimeoutInMsec = 250;
3741

3842
constexpr DNSServiceFlags kRegisterFlags = kDNSServiceFlagsNoAutoRename;
3943
constexpr DNSServiceFlags kBrowseFlags = kDNSServiceFlagsShareConnection;
@@ -146,57 +150,66 @@ std::string GetDomainFromHostName(const char * hostnameWithDomain)
146150

147151
// Find the last occurence of '.'
148152
size_t last_pos = hostname.find_last_of(".");
149-
if (last_pos != std::string::npos)
150-
{
151-
// Get a substring without last '.'
152-
std::string substring = hostname.substr(0, last_pos);
153153

154-
// Find the last occurence of '.' in the substring created above.
155-
size_t pos = substring.find_last_of(".");
156-
if (pos != std::string::npos)
157-
{
158-
// Return the domain name between the last 2 occurences of '.' including the trailing dot'.'.
159-
return std::string(hostname.substr(pos + 1, last_pos));
160-
}
154+
// The last occurence of the dot should be the last character in the hostname with domain.
155+
VerifyOrReturnValue((last_pos != std::string::npos && (last_pos == strlen(hostnameWithDomain) - 1)), std::string());
156+
157+
// Get a substring without last '.'
158+
std::string substring = hostname.substr(0, last_pos);
159+
160+
// Find the last occurence of '.' in the substring created above.
161+
size_t pos = substring.find_last_of(".");
162+
if (pos != std::string::npos)
163+
{
164+
// Return the domain name between the last 2 occurences of '.' including the trailing dot'.'.
165+
return std::string(hostname.substr(pos + 1, last_pos));
161166
}
167+
162168
return std::string();
163169
}
164170

165-
Global<MdnsContexts> MdnsContexts::sInstance;
166-
167-
namespace {
168-
169171
/**
170-
* @brief Callback that is called when the timeout for resolving on the kOpenThreadDot domain has expired.
172+
* @brief Callback that is called when the timeout for resolving on the kSrpDot domain has expired.
171173
*
172174
* @param[in] systemLayer The system layer.
173175
* @param[in] callbackContext The context passed to the timer callback.
174176
*/
175-
void OpenThreadTimerExpiredCallback(System::Layer * systemLayer, void * callbackContext)
177+
void SrpTimerExpiredCallback(System::Layer * systemLayer, void * callbackContext)
176178
{
177-
ChipLogProgress(Discovery, "Mdns: Timer expired for resolve to complete on the open thread domain.");
179+
ChipLogProgress(Discovery, "Mdns: Timer expired for resolve to complete on the srp domain.");
178180
auto sdCtx = static_cast<ResolveContext *>(callbackContext);
179181
VerifyOrDie(sdCtx != nullptr);
180-
181-
if (sdCtx->hasOpenThreadTimerStarted)
182-
{
183-
sdCtx->Finalize();
184-
}
182+
sdCtx->Finalize();
185183
}
186184

187185
/**
188-
* @brief Starts a timer to wait for the resolution on the kOpenThreadDot domain to happen.
186+
* @brief Starts a timer to wait for the resolution on the kSrpDot domain to happen.
189187
*
190188
* @param[in] timeoutSeconds The timeout in seconds.
191189
* @param[in] ResolveContext The resolve context.
192190
*/
193-
void StartOpenThreadTimer(uint16_t timeoutInMSecs, ResolveContext * ctx)
191+
CHIP_ERROR StartSrpTimer(uint16_t timeoutInMSecs, ResolveContext * ctx)
194192
{
195-
VerifyOrReturn(ctx != nullptr, ChipLogError(Discovery, "Can't schedule open thread timer since context is null"));
196-
DeviceLayer::SystemLayer().StartTimer(System::Clock::Milliseconds16(timeoutInMSecs), OpenThreadTimerExpiredCallback,
193+
VerifyOrReturnValue(ctx != nullptr, CHIP_ERROR_INCORRECT_STATE);
194+
return DeviceLayer::SystemLayer().StartTimer(System::Clock::Milliseconds16(timeoutInMSecs), SrpTimerExpiredCallback,
197195
reinterpret_cast<void *>(ctx));
198196
}
199197

198+
/**
199+
* @brief Cancels the timer that was started to wait for the resolution on the kSrpDot domain to happen.
200+
*
201+
* @param[in] ResolveContext The resolve context.
202+
*/
203+
void CancelSrpTimer(ResolveContext * ctx)
204+
{
205+
DeviceLayer::SystemLayer().CancelTimer(SrpTimerExpiredCallback, reinterpret_cast<void *>(ctx));
206+
}
207+
208+
209+
Global<MdnsContexts> MdnsContexts::sInstance;
210+
211+
namespace {
212+
200213
static void OnRegister(DNSServiceRef sdRef, DNSServiceFlags flags, DNSServiceErrorType err, const char * name, const char * type,
201214
const char * domain, void * context)
202215
{
@@ -248,17 +261,17 @@ CHIP_ERROR Browse(BrowseHandler * sdCtx, uint32_t interfaceId, const char * type
248261
auto err = DNSServiceCreateConnection(&sdCtx->serviceRef);
249262
VerifyOrReturnError(kDNSServiceErr_NoError == err, sdCtx->Finalize(err));
250263

251-
// We will browse on both the local domain and the open thread domain.
264+
// We will browse on both the local domain and the srp domain.
252265
ChipLogProgress(Discovery, "Browsing for: %s on domain %s", StringOrNullMarker(type), kLocalDot);
253266

254267
auto sdRefLocal = sdCtx->serviceRef; // Mandatory copy because of kDNSServiceFlagsShareConnection
255268
err = DNSServiceBrowse(&sdRefLocal, kBrowseFlags, interfaceId, type, kLocalDot, OnBrowse, sdCtx);
256269
VerifyOrReturnError(kDNSServiceErr_NoError == err, sdCtx->Finalize(err));
257270

258-
ChipLogProgress(Discovery, "Browsing for: %s on domain %s", StringOrNullMarker(type), kOpenThreadDot);
271+
ChipLogProgress(Discovery, "Browsing for: %s on domain %s", StringOrNullMarker(type), kSrpDot);
259272

260-
DNSServiceRef sdRefOpenThread = sdCtx->serviceRef; // Mandatory copy because of kDNSServiceFlagsShareConnection
261-
err = DNSServiceBrowse(&sdRefOpenThread, kBrowseFlags, interfaceId, type, kOpenThreadDot, OnBrowse, sdCtx);
273+
auto sdRefSrp = sdCtx->serviceRef; // Mandatory copy because of kDNSServiceFlagsShareConnection
274+
err = DNSServiceBrowse(&sdRefSrp, kBrowseFlags, interfaceId, type, kSrpDot, OnBrowse, sdCtx);
262275
VerifyOrReturnError(kDNSServiceErr_NoError == err, sdCtx->Finalize(err));
263276

264277
return MdnsContexts::GetInstance().Add(sdCtx, sdCtx->serviceRef);
@@ -307,25 +320,38 @@ static void OnGetAddrInfo(DNSServiceRef sdRef, DNSServiceFlags flags, uint32_t i
307320
{
308321
VerifyOrReturn(sdCtx->HasAddress(), sdCtx->Finalize(kDNSServiceErr_BadState));
309322

310-
if (domainName.compare(kOpenThreadDot) == 0)
323+
if (domainName.compare(kSrpDot) == 0)
311324
{
312-
ChipLogProgress(Discovery, "Mdns: Resolve completed on the open thread domain.");
325+
ChipLogProgress(Discovery, "Mdns: Resolve completed on the srp domain.");
326+
327+
// Cancel the timer if one has been started
328+
if (sdCtx->hasSrpTimerStarted)
329+
{
330+
CancelSrpTimer(sdCtx);
331+
}
313332
sdCtx->Finalize();
314333
}
315334
else if (domainName.compare(kLocalDot) == 0)
316335
{
317336
ChipLogProgress(
318337
Discovery,
319-
"Mdns: Resolve completed on the local domain. Starting a timer for the open thread resolve to come back");
338+
"Mdns: Resolve completed on the local domain. Starting a timer for the srp resolve to come back");
320339

321-
// Usually the resolution on the local domain is quicker than on the open thread domain. We would like to give the
322-
// resolution on the open thread domain around 250 millisecs more to give it a chance to resolve before finalizing
340+
// Usually the resolution on the local domain is quicker than on the srp domain. We would like to give the
341+
// resolution on the srp domain around 250 millisecs more to give it a chance to resolve before finalizing
323342
// the resolution.
324-
if (!sdCtx->hasOpenThreadTimerStarted)
343+
if (!sdCtx->hasSrpTimerStarted)
325344
{
326-
// Schedule a timer to allow the resolve on OpenThread domain to complete.
327-
StartOpenThreadTimer(kOpenThreadTimeoutInMsec, sdCtx);
328-
sdCtx->hasOpenThreadTimerStarted = true;
345+
// Schedule a timer to allow the resolve on Srp domain to complete.
346+
CHIP_ERROR error = StartSrpTimer(kSrpTimeoutInMsec, sdCtx);
347+
348+
// If the timer fails to start, finalize the context and return.
349+
if (error != CHIP_NO_ERROR)
350+
{
351+
sdCtx->Finalize();
352+
return;
353+
}
354+
sdCtx->hasSrpTimerStarted = true;
329355
}
330356
}
331357
}
@@ -368,7 +394,6 @@ static void OnResolve(DNSServiceRef sdRef, DNSServiceFlags flags, uint32_t inter
368394
{
369395
GetAddrInfo(sdCtx);
370396
sdCtx->isResolveRequested = true;
371-
sdCtx->hasOpenThreadTimerStarted = false;
372397
}
373398
}
374399
}
@@ -382,13 +407,13 @@ static CHIP_ERROR Resolve(ResolveContext * sdCtx, uint32_t interfaceId, chip::In
382407
auto err = DNSServiceCreateConnection(&sdCtx->serviceRef);
383408
VerifyOrReturnError(kDNSServiceErr_NoError == err, sdCtx->Finalize(err));
384409

385-
// Similar to browse, will try to resolve using both the local domain and the open thread domain.
410+
// Similar to browse, will try to resolve using both the local domain and the srp domain.
386411
auto sdRefLocal = sdCtx->serviceRef; // Mandatory copy because of kDNSServiceFlagsShareConnection
387412
err = DNSServiceResolve(&sdRefLocal, kResolveFlags, interfaceId, name, type, kLocalDot, OnResolve, sdCtx);
388413
VerifyOrReturnError(kDNSServiceErr_NoError == err, sdCtx->Finalize(err));
389414

390-
auto sdRefOpenThread = sdCtx->serviceRef; // Mandatory copy because of kDNSServiceFlagsShareConnection
391-
err = DNSServiceResolve(&sdRefOpenThread, kResolveFlags, interfaceId, name, type, kOpenThreadDot, OnResolve, sdCtx);
415+
auto sdRefSrp = sdCtx->serviceRef; // Mandatory copy because of kDNSServiceFlagsShareConnection
416+
err = DNSServiceResolve(&sdRefSrp, kResolveFlags, interfaceId, name, type, kSrpDot, OnResolve, sdCtx);
392417
VerifyOrReturnError(kDNSServiceErr_NoError == err, sdCtx->Finalize(err));
393418

394419
auto retval = MdnsContexts::GetInstance().Add(sdCtx, sdCtx->serviceRef);

0 commit comments

Comments
 (0)