Skip to content

Commit dcb4444

Browse files
[Python] fixed subscription crash (#32257)
* [Python] After SubscriptionTransaction has an error, calling Shutdown() will crash * Add a comment * update comment
1 parent 4df081c commit dcb4444

File tree

2 files changed

+96
-55
lines changed

2 files changed

+96
-55
lines changed

src/controller/python/chip/clusters/Attribute.py

+80-46
Original file line numberDiff line numberDiff line change
@@ -110,9 +110,11 @@ def __post_init__(self):
110110
'''Only one of either ClusterType and AttributeType OR Path may be provided.'''
111111

112112
if (self.ClusterType is not None and self.AttributeType is not None) and self.Path is not None:
113-
raise ValueError("Only one of either ClusterType and AttributeType OR Path may be provided.")
113+
raise ValueError(
114+
"Only one of either ClusterType and AttributeType OR Path may be provided.")
114115
if (self.ClusterType is None or self.AttributeType is None) and self.Path is None:
115-
raise ValueError("Either ClusterType and AttributeType OR Path must be provided.")
116+
raise ValueError(
117+
"Either ClusterType and AttributeType OR Path must be provided.")
116118

117119
# if ClusterType and AttributeType were provided we can continue onwards to deriving the label.
118120
# Otherwise, we'll need to walk the attribute index to find the right type information.
@@ -373,7 +375,8 @@ def handle_cluster_view(endpointId, clusterId, clusterType):
373375
try:
374376
decodedData = clusterType.FromDict(
375377
data=clusterType.descriptor.TagDictToLabelDict([], self.attributeTLVCache[endpointId][clusterId]))
376-
decodedData.SetDataVersion(self.versionList.get(endpointId, {}).get(clusterId))
378+
decodedData.SetDataVersion(
379+
self.versionList.get(endpointId, {}).get(clusterId))
377380
return decodedData
378381
except Exception as ex:
379382
return ValueDecodeFailure(self.attributeTLVCache[endpointId][clusterId], ex)
@@ -404,12 +407,14 @@ def handle_attribute_view(endpointId, clusterId, attributeId, attributeType):
404407
clusterType = _ClusterIndex[clusterId]
405408

406409
if self.returnClusterObject:
407-
endpointCache[clusterType] = handle_cluster_view(endpointId, clusterId, clusterType)
410+
endpointCache[clusterType] = handle_cluster_view(
411+
endpointId, clusterId, clusterType)
408412
else:
409413
if clusterType not in endpointCache:
410414
endpointCache[clusterType] = {}
411415
clusterCache = endpointCache[clusterType]
412-
clusterCache[DataVersion] = self.versionList.get(endpointId, {}).get(clusterId)
416+
clusterCache[DataVersion] = self.versionList.get(
417+
endpointId, {}).get(clusterId)
413418

414419
if (clusterId, attributeId) not in _AttributeIndex:
415420
#
@@ -419,7 +424,8 @@ def handle_attribute_view(endpointId, clusterId, attributeId, attributeType):
419424
continue
420425

421426
attributeType = _AttributeIndex[(clusterId, attributeId)][0]
422-
clusterCache[attributeType] = handle_attribute_view(endpointId, clusterId, attributeId, attributeType)
427+
clusterCache[attributeType] = handle_attribute_view(
428+
endpointId, clusterId, attributeId, attributeType)
423429
self._attributeCacheUpdateNeeded.clear()
424430
return self._attributeCache
425431

@@ -428,14 +434,18 @@ class SubscriptionTransaction:
428434
def __init__(self, transaction: AsyncReadTransaction, subscriptionId, devCtrl):
429435
self._onResubscriptionAttemptedCb: Callable[[SubscriptionTransaction,
430436
int, int], None] = DefaultResubscriptionAttemptedCallback
431-
self._onAttributeChangeCb: Callable[[TypedAttributePath, SubscriptionTransaction], None] = DefaultAttributeChangeCallback
432-
self._onEventChangeCb: Callable[[EventReadResult, SubscriptionTransaction], None] = DefaultEventChangeCallback
433-
self._onErrorCb: Callable[[int, SubscriptionTransaction], None] = DefaultErrorCallback
437+
self._onAttributeChangeCb: Callable[[
438+
TypedAttributePath, SubscriptionTransaction], None] = DefaultAttributeChangeCallback
439+
self._onEventChangeCb: Callable[[
440+
EventReadResult, SubscriptionTransaction], None] = DefaultEventChangeCallback
441+
self._onErrorCb: Callable[[
442+
int, SubscriptionTransaction], None] = DefaultErrorCallback
434443
self._readTransaction = transaction
435444
self._subscriptionId = subscriptionId
436445
self._devCtrl = devCtrl
437446
self._isDone = False
438-
self._onResubscriptionSucceededCb: Optional[Callable[[SubscriptionTransaction], None]] = None
447+
self._onResubscriptionSucceededCb: Optional[Callable[[
448+
SubscriptionTransaction], None]] = None
439449
self._onResubscriptionSucceededCb_isAsync = False
440450
self._onResubscriptionAttemptedCb_isAsync = False
441451

@@ -460,7 +470,8 @@ def GetEvents(self):
460470
def OverrideLivenessTimeoutMs(self, timeoutMs: int):
461471
handle = chip.native.GetLibraryHandle()
462472
builtins.chipStack.Call(
463-
lambda: handle.pychip_ReadClient_OverrideLivenessTimeout(self._readTransaction._pReadClient, timeoutMs)
473+
lambda: handle.pychip_ReadClient_OverrideLivenessTimeout(
474+
self._readTransaction._pReadClient, timeoutMs)
464475
)
465476

466477
async def TriggerResubscribeIfScheduled(self, reason: str):
@@ -501,7 +512,8 @@ def GetSubscriptionTimeoutMs(self) -> int:
501512
timeoutMs = ctypes.c_uint32(0)
502513
handle = chip.native.GetLibraryHandle()
503514
builtins.chipStack.Call(
504-
lambda: handle.pychip_ReadClient_GetSubscriptionTimeoutMs(self._readTransaction._pReadClient, ctypes.pointer(timeoutMs))
515+
lambda: handle.pychip_ReadClient_GetSubscriptionTimeoutMs(
516+
self._readTransaction._pReadClient, ctypes.pointer(timeoutMs))
505517
)
506518
return timeoutMs.value
507519

@@ -567,13 +579,14 @@ def subscriptionId(self) -> int:
567579

568580
def Shutdown(self):
569581
if (self._isDone):
570-
LOGGER.warning("Subscription 0x%08x was already terminated previously!", self.subscriptionId)
582+
LOGGER.warning(
583+
"Subscription 0x%08x was already terminated previously!", self.subscriptionId)
571584
return
572585

573586
handle = chip.native.GetLibraryHandle()
574587
builtins.chipStack.Call(
575-
lambda: handle.pychip_ReadClient_Abort(
576-
self._readTransaction._pReadClient, self._readTransaction._pReadCallback))
588+
lambda: handle.pychip_ReadClient_ShutdownSubscription(
589+
self._readTransaction._pReadClient))
577590
self._isDone = True
578591

579592
def __del__(self):
@@ -585,7 +598,8 @@ def __repr__(self):
585598

586599
def DefaultResubscriptionAttemptedCallback(transaction: SubscriptionTransaction,
587600
terminationError, nextResubscribeIntervalMsec):
588-
print(f"Previous subscription failed with Error: {terminationError} - re-subscribing in {nextResubscribeIntervalMsec}ms...")
601+
print(
602+
f"Previous subscription failed with Error: {terminationError} - re-subscribing in {nextResubscribeIntervalMsec}ms...")
589603

590604

591605
def DefaultAttributeChangeCallback(path: TypedAttributePath, transaction: SubscriptionTransaction):
@@ -648,12 +662,10 @@ def __init__(self, future: Future, eventLoop, devCtrl, returnClusterObject: bool
648662
self._cache = AttributeCache(returnClusterObject=returnClusterObject)
649663
self._changedPathSet: Set[AttributePath] = set()
650664
self._pReadClient = None
651-
self._pReadCallback = None
652665
self._resultError: Optional[PyChipError] = None
653666

654-
def SetClientObjPointers(self, pReadClient, pReadCallback):
667+
def SetClientObjPointers(self, pReadClient):
655668
self._pReadClient = pReadClient
656-
self._pReadCallback = pReadCallback
657669

658670
def GetAllEventValues(self):
659671
return self._events
@@ -729,7 +741,8 @@ def handleEventData(self, header: EventHeader, path: EventPath, data: bytes, sta
729741

730742
def handleError(self, chipError: PyChipError):
731743
if self._subscription_handler:
732-
self._subscription_handler.OnErrorCb(chipError.code, self._subscription_handler)
744+
self._subscription_handler.OnErrorCb(
745+
chipError.code, self._subscription_handler)
733746
self._resultError = chipError
734747

735748
def _handleSubscriptionEstablished(self, subscriptionId):
@@ -744,7 +757,8 @@ def _handleSubscriptionEstablished(self, subscriptionId):
744757
self._event_loop.create_task(
745758
self._subscription_handler._onResubscriptionSucceededCb(self._subscription_handler))
746759
else:
747-
self._subscription_handler._onResubscriptionSucceededCb(self._subscription_handler)
760+
self._subscription_handler._onResubscriptionSucceededCb(
761+
self._subscription_handler)
748762

749763
def handleSubscriptionEstablished(self, subscriptionId):
750764
self._event_loop.call_soon_threadsafe(
@@ -820,7 +834,8 @@ def __init__(self, future: Future, eventLoop):
820834
def handleResponse(self, path: AttributePath, status: int):
821835
try:
822836
imStatus = chip.interaction_model.Status(status)
823-
self._resultData.append(AttributeWriteResult(Path=path, Status=imStatus))
837+
self._resultData.append(
838+
AttributeWriteResult(Path=path, Status=imStatus))
824839
except ValueError as ex:
825840
LOGGER.exception(ex)
826841

@@ -835,8 +850,10 @@ def _handleDone(self):
835850
#
836851
if self._resultError is not None:
837852
if self._resultError.sdk_part is ErrorSDKPart.IM_GLOBAL_STATUS:
838-
im_status = chip.interaction_model.Status(self._resultError.sdk_code)
839-
self._future.set_exception(chip.interaction_model.InteractionModelError(im_status))
853+
im_status = chip.interaction_model.Status(
854+
self._resultError.sdk_code)
855+
self._future.set_exception(
856+
chip.interaction_model.InteractionModelError(im_status))
840857
else:
841858
self._future.set_exception(self._resultError.to_exception())
842859
else:
@@ -856,7 +873,8 @@ def handleDone(self):
856873
_OnReadAttributeDataCallbackFunct = CFUNCTYPE(
857874
None, py_object, c_uint32, c_uint16, c_uint32, c_uint32, c_uint8, c_void_p, c_size_t)
858875
_OnSubscriptionEstablishedCallbackFunct = CFUNCTYPE(None, py_object, c_uint32)
859-
_OnResubscriptionAttemptedCallbackFunct = CFUNCTYPE(None, py_object, PyChipError, c_uint32)
876+
_OnResubscriptionAttemptedCallbackFunct = CFUNCTYPE(
877+
None, py_object, PyChipError, c_uint32)
860878
_OnReadEventDataCallbackFunct = CFUNCTYPE(
861879
None, py_object, c_uint16, c_uint32, c_uint32, c_uint64, c_uint8, c_uint64, c_uint8, c_void_p, c_size_t, c_uint8)
862880
_OnReadErrorCallbackFunct = CFUNCTYPE(
@@ -897,7 +915,8 @@ def _OnSubscriptionEstablishedCallback(closure, subscriptionId):
897915

898916
@_OnResubscriptionAttemptedCallbackFunct
899917
def _OnResubscriptionAttemptedCallback(closure, terminationCause: PyChipError, nextResubscribeIntervalMsec: int):
900-
closure.handleResubscriptionAttempted(terminationCause, nextResubscribeIntervalMsec)
918+
closure.handleResubscriptionAttempted(
919+
terminationCause, nextResubscribeIntervalMsec)
901920

902921

903922
@_OnReadErrorCallbackFunct
@@ -954,25 +973,34 @@ def WriteAttributes(future: Future, eventLoop, device,
954973
pyWriteAttributes = pyWriteAttributesArrayType()
955974
for idx, attr in enumerate(attributes):
956975
if attr.Attribute.must_use_timed_write and timedRequestTimeoutMs is None or timedRequestTimeoutMs == 0:
957-
raise chip.interaction_model.InteractionModelError(chip.interaction_model.Status.NeedsTimedInteraction)
976+
raise chip.interaction_model.InteractionModelError(
977+
chip.interaction_model.Status.NeedsTimedInteraction)
958978

959979
tlv = attr.Attribute.ToTLV(None, attr.Data)
960980

961-
pyWriteAttributes[idx].attributePath.endpointId = c_uint16(attr.EndpointId)
962-
pyWriteAttributes[idx].attributePath.clusterId = c_uint32(attr.Attribute.cluster_id)
963-
pyWriteAttributes[idx].attributePath.attributeId = c_uint32(attr.Attribute.attribute_id)
964-
pyWriteAttributes[idx].attributePath.dataVersion = c_uint32(attr.DataVersion)
965-
pyWriteAttributes[idx].attributePath.hasDataVersion = c_uint8(attr.HasDataVersion)
966-
pyWriteAttributes[idx].tlvData = cast(ctypes.c_char_p(bytes(tlv)), c_void_p)
981+
pyWriteAttributes[idx].attributePath.endpointId = c_uint16(
982+
attr.EndpointId)
983+
pyWriteAttributes[idx].attributePath.clusterId = c_uint32(
984+
attr.Attribute.cluster_id)
985+
pyWriteAttributes[idx].attributePath.attributeId = c_uint32(
986+
attr.Attribute.attribute_id)
987+
pyWriteAttributes[idx].attributePath.dataVersion = c_uint32(
988+
attr.DataVersion)
989+
pyWriteAttributes[idx].attributePath.hasDataVersion = c_uint8(
990+
attr.HasDataVersion)
991+
pyWriteAttributes[idx].tlvData = cast(
992+
ctypes.c_char_p(bytes(tlv)), c_void_p)
967993
pyWriteAttributes[idx].tlvLength = c_size_t(len(tlv))
968994

969995
transaction = AsyncWriteTransaction(future, eventLoop)
970996
ctypes.pythonapi.Py_IncRef(ctypes.py_object(transaction))
971997
res = builtins.chipStack.Call(
972998
lambda: handle.pychip_WriteClient_WriteAttributes(
973999
ctypes.py_object(transaction), device,
974-
ctypes.c_size_t(0 if timedRequestTimeoutMs is None else timedRequestTimeoutMs),
975-
ctypes.c_size_t(0 if interactionTimeoutMs is None else interactionTimeoutMs),
1000+
ctypes.c_size_t(
1001+
0 if timedRequestTimeoutMs is None else timedRequestTimeoutMs),
1002+
ctypes.c_size_t(
1003+
0 if interactionTimeoutMs is None else interactionTimeoutMs),
9761004
ctypes.c_size_t(0 if busyWaitMs is None else busyWaitMs),
9771005
pyWriteAttributes, ctypes.c_size_t(numberOfAttributes))
9781006
)
@@ -991,12 +1019,18 @@ def WriteGroupAttributes(groupId: int, devCtrl: c_void_p, attributes: List[Attri
9911019

9921020
tlv = attr.Attribute.ToTLV(None, attr.Data)
9931021

994-
pyWriteAttributes[idx].attributePath.endpointId = c_uint16(attr.EndpointId)
995-
pyWriteAttributes[idx].attributePath.clusterId = c_uint32(attr.Attribute.cluster_id)
996-
pyWriteAttributes[idx].attributePath.attributeId = c_uint32(attr.Attribute.attribute_id)
997-
pyWriteAttributes[idx].attributePath.dataVersion = c_uint32(attr.DataVersion)
998-
pyWriteAttributes[idx].attributePath.hasDataVersion = c_uint8(attr.HasDataVersion)
999-
pyWriteAttributes[idx].tlvData = cast(ctypes.c_char_p(bytes(tlv)), c_void_p)
1022+
pyWriteAttributes[idx].attributePath.endpointId = c_uint16(
1023+
attr.EndpointId)
1024+
pyWriteAttributes[idx].attributePath.clusterId = c_uint32(
1025+
attr.Attribute.cluster_id)
1026+
pyWriteAttributes[idx].attributePath.attributeId = c_uint32(
1027+
attr.Attribute.attribute_id)
1028+
pyWriteAttributes[idx].attributePath.dataVersion = c_uint32(
1029+
attr.DataVersion)
1030+
pyWriteAttributes[idx].attributePath.hasDataVersion = c_uint8(
1031+
attr.HasDataVersion)
1032+
pyWriteAttributes[idx].tlvData = cast(
1033+
ctypes.c_char_p(bytes(tlv)), c_void_p)
10001034
pyWriteAttributes[idx].tlvLength = c_size_t(len(tlv))
10011035

10021036
return builtins.chipStack.Call(
@@ -1071,7 +1105,8 @@ def Read(transaction: AsyncReadTransaction, device,
10711105
"DataVersionFilter must provide DataVersion.")
10721106
filter = chip.interaction_model.DataVersionFilterIBstruct.build(
10731107
filter)
1074-
dataVersionFiltersForCffi[idx] = cast(ctypes.c_char_p(filter), c_void_p)
1108+
dataVersionFiltersForCffi[idx] = cast(
1109+
ctypes.c_char_p(filter), c_void_p)
10751110

10761111
eventPathsForCffi = None
10771112
if events is not None:
@@ -1095,7 +1130,6 @@ def Read(transaction: AsyncReadTransaction, device,
10951130
eventPathsForCffi[idx] = cast(ctypes.c_char_p(path), c_void_p)
10961131

10971132
readClientObj = ctypes.POINTER(c_void_p)()
1098-
readCallbackObj = ctypes.POINTER(c_void_p)()
10991133

11001134
ctypes.pythonapi.Py_IncRef(ctypes.py_object(transaction))
11011135
params = _ReadParams.parse(b'\x00' * _ReadParams.sizeof())
@@ -1109,13 +1143,13 @@ def Read(transaction: AsyncReadTransaction, device,
11091143
params = _ReadParams.build(params)
11101144
eventNumberFilterPtr = ctypes.POINTER(ctypes.c_ulonglong)()
11111145
if eventNumberFilter is not None:
1112-
eventNumberFilterPtr = ctypes.POINTER(ctypes.c_ulonglong)(ctypes.c_ulonglong(eventNumberFilter))
1146+
eventNumberFilterPtr = ctypes.POINTER(ctypes.c_ulonglong)(
1147+
ctypes.c_ulonglong(eventNumberFilter))
11131148

11141149
res = builtins.chipStack.Call(
11151150
lambda: handle.pychip_ReadClient_Read(
11161151
ctypes.py_object(transaction),
11171152
ctypes.byref(readClientObj),
1118-
ctypes.byref(readCallbackObj),
11191153
device,
11201154
ctypes.c_char_p(params),
11211155
attributePathsForCffi,
@@ -1127,7 +1161,7 @@ def Read(transaction: AsyncReadTransaction, device,
11271161
ctypes.c_size_t(0 if events is None else len(events)),
11281162
eventNumberFilterPtr))
11291163

1130-
transaction.SetClientObjPointers(readClientObj, readCallbackObj)
1164+
transaction.SetClientObjPointers(readClientObj)
11311165

11321166
if not res.is_success:
11331167
ctypes.pythonapi.Py_DecRef(ctypes.py_object(transaction))

src/controller/python/chip/clusters/attribute.cpp

+16-9
Original file line numberDiff line numberDiff line change
@@ -453,12 +453,20 @@ PyChipError pychip_WriteClient_WriteGroupAttributes(size_t groupIdSizeT, chip::C
453453
return ToPyChipError(err);
454454
}
455455

456-
void pychip_ReadClient_Abort(ReadClient * apReadClient, ReadClientCallback * apCallback)
456+
void pychip_ReadClient_ShutdownSubscription(ReadClient * apReadClient)
457457
{
458-
VerifyOrDie(apReadClient != nullptr);
459-
VerifyOrDie(apCallback != nullptr);
458+
// If apReadClient is nullptr, it means that its life cycle has ended (such as an error happend), and nothing needs to be done.
459+
VerifyOrReturn(apReadClient != nullptr);
460+
// If it is not SubscriptionType, this function should not be executed.
461+
VerifyOrDie(apReadClient->IsSubscriptionType());
460462

461-
delete apCallback;
463+
Optional<SubscriptionId> subscriptionId = apReadClient->GetSubscriptionId();
464+
VerifyOrDie(subscriptionId.HasValue());
465+
466+
FabricIndex fabricIndex = apReadClient->GetFabricIndex();
467+
NodeId nodeId = apReadClient->GetPeerNodeId();
468+
469+
InteractionModelEngine::GetInstance()->ShutdownSubscription(ScopedNodeId(nodeId, fabricIndex), subscriptionId.Value());
462470
}
463471

464472
void pychip_ReadClient_OverrideLivenessTimeout(ReadClient * pReadClient, uint32_t livenessTimeoutMs)
@@ -497,10 +505,10 @@ void pychip_ReadClient_GetSubscriptionTimeoutMs(ReadClient * pReadClient, uint32
497505
}
498506
}
499507

500-
PyChipError pychip_ReadClient_Read(void * appContext, ReadClient ** pReadClient, ReadClientCallback ** pCallback,
501-
DeviceProxy * device, uint8_t * readParamsBuf, void ** attributePathsFromPython,
502-
size_t numAttributePaths, void ** dataversionFiltersFromPython, size_t numDataversionFilters,
503-
void ** eventPathsFromPython, size_t numEventPaths, uint64_t * eventNumberFilter)
508+
PyChipError pychip_ReadClient_Read(void * appContext, ReadClient ** pReadClient, DeviceProxy * device, uint8_t * readParamsBuf,
509+
void ** attributePathsFromPython, size_t numAttributePaths, void ** dataversionFiltersFromPython,
510+
size_t numDataversionFilters, void ** eventPathsFromPython, size_t numEventPaths,
511+
uint64_t * eventNumberFilter)
504512
{
505513
CHIP_ERROR err = CHIP_NO_ERROR;
506514
PyReadAttributeParams pyParams = {};
@@ -612,7 +620,6 @@ PyChipError pychip_ReadClient_Read(void * appContext, ReadClient ** pReadClient,
612620
}
613621

614622
*pReadClient = readClient.get();
615-
*pCallback = callback.get();
616623

617624
callback->AdoptReadClient(std::move(readClient));
618625

0 commit comments

Comments
 (0)