Skip to content

Commit 26be584

Browse files
authored
[Python] Implement async friendly GetConnectedDevice (project-chip#32760)
* [Python] Implement async friendly GetConnectedDevice Currently GetConnectedDeviceSync() is blocking e.g. when a new session needs to be created. This is not asyncio friendly as it blocks the whole event loop. Implement a asyncio friendly variant GetConnectedDevice() which is a co-routine function which can be awaited. * Drop bracket * Skip timeout only when passing None Change semantics slightly to make 0 mean don't wait at all. * Add API docs
1 parent 007e11d commit 26be584

File tree

1 file changed

+75
-9
lines changed

1 file changed

+75
-9
lines changed

src/controller/python/chip/ChipDeviceCtrl.py

+75-9
Original file line numberDiff line numberDiff line change
@@ -787,8 +787,16 @@ def GetClusterHandler(self):
787787

788788
return self._Cluster
789789

790-
def GetConnectedDeviceSync(self, nodeid, allowPASE=True, timeoutMs: int = None):
791-
''' Returns DeviceProxyWrapper upon success.'''
790+
def GetConnectedDeviceSync(self, nodeid, allowPASE: bool = True, timeoutMs: int = None):
791+
''' Gets an OperationalDeviceProxy or CommissioneeDeviceProxy for the specified Node.
792+
793+
nodeId: Target's Node ID
794+
allowPASE: Get a device proxy of a device being commissioned.
795+
timeoutMs: Timeout for a timed invoke request. Omit or set to 'None' to indicate a non-timed request.
796+
797+
Returns:
798+
- DeviceProxyWrapper on success
799+
'''
792800
self.CheckIsActive()
793801

794802
returnDevice = c_void_p(None)
@@ -824,7 +832,7 @@ def deviceAvailable(self, device, err):
824832
if returnDevice.value is None:
825833
with deviceAvailableCV:
826834
timeout = None
827-
if (timeoutMs):
835+
if timeoutMs is not None:
828836
timeout = float(timeoutMs) / 1000
829837

830838
ret = deviceAvailableCV.wait(timeout)
@@ -836,6 +844,64 @@ def deviceAvailable(self, device, err):
836844

837845
return DeviceProxyWrapper(returnDevice, self._dmLib)
838846

847+
async def GetConnectedDevice(self, nodeid, allowPASE: bool = True, timeoutMs: int = None):
848+
''' Gets an OperationalDeviceProxy or CommissioneeDeviceProxy for the specified Node.
849+
850+
nodeId: Target's Node ID
851+
allowPASE: Get a device proxy of a device being commissioned.
852+
timeoutMs: Timeout for a timed invoke request. Omit or set to 'None' to indicate a non-timed request.
853+
854+
Returns:
855+
- DeviceProxyWrapper on success
856+
'''
857+
self.CheckIsActive()
858+
859+
if allowPASE:
860+
returnDevice = c_void_p(None)
861+
res = self._ChipStack.Call(lambda: self._dmLib.pychip_GetDeviceBeingCommissioned(
862+
self.devCtrl, nodeid, byref(returnDevice)), timeoutMs)
863+
if res.is_success:
864+
logging.info('Using PASE connection')
865+
return DeviceProxyWrapper(returnDevice)
866+
867+
eventLoop = asyncio.get_running_loop()
868+
future = eventLoop.create_future()
869+
870+
class DeviceAvailableClosure():
871+
def __init__(self, loop, future: asyncio.Future):
872+
self._returnDevice = c_void_p(None)
873+
self._returnErr = None
874+
self._event_loop = loop
875+
self._future = future
876+
877+
def _deviceAvailable(self):
878+
if self._returnDevice.value is not None:
879+
self._future.set_result(self._returnDevice)
880+
else:
881+
self._future.set_exception(self._returnErr.to_exception())
882+
883+
def deviceAvailable(self, device, err):
884+
self._returnDevice = c_void_p(device)
885+
self._returnErr = err
886+
self._event_loop.call_soon_threadsafe(self._deviceAvailable)
887+
ctypes.pythonapi.Py_DecRef(ctypes.py_object(self))
888+
889+
closure = DeviceAvailableClosure(eventLoop, future)
890+
ctypes.pythonapi.Py_IncRef(ctypes.py_object(closure))
891+
self._ChipStack.Call(lambda: self._dmLib.pychip_GetConnectedDeviceByNodeId(
892+
self.devCtrl, nodeid, ctypes.py_object(closure), _DeviceAvailableCallback),
893+
timeoutMs).raise_on_error()
894+
895+
# The callback might have been received synchronously (during self._ChipStack.Call()).
896+
# In that case the Future has already been set it will return immediately
897+
if timeoutMs is not None:
898+
timeout = float(timeoutMs) / 1000
899+
await asyncio.wait_for(future, timeout=timeout)
900+
else:
901+
await future
902+
903+
return DeviceProxyWrapper(future.result(), self._dmLib)
904+
839905
def ComputeRoundTripTimeout(self, nodeid, upperLayerProcessingTimeoutMs: int = 0):
840906
''' Returns a computed timeout value based on the round-trip time it takes for the peer at the other end of the session to
841907
receive a message, process it and send it back. This is computed based on the session type, the type of transport,
@@ -900,7 +966,7 @@ async def TestOnlySendBatchCommands(self, nodeid: int, commands: typing.List[Clu
900966
eventLoop = asyncio.get_running_loop()
901967
future = eventLoop.create_future()
902968

903-
device = self.GetConnectedDeviceSync(nodeid, timeoutMs=interactionTimeoutMs)
969+
device = await self.GetConnectedDevice(nodeid, timeoutMs=interactionTimeoutMs)
904970

905971
ClusterCommand.TestOnlySendBatchCommands(
906972
future, eventLoop, device.deviceProxy, commands,
@@ -921,7 +987,7 @@ async def TestOnlySendCommandTimedRequestFlagWithNoTimedInvoke(self, nodeid: int
921987
eventLoop = asyncio.get_running_loop()
922988
future = eventLoop.create_future()
923989

924-
device = self.GetConnectedDeviceSync(nodeid, timeoutMs=None)
990+
device = await self.GetConnectedDevice(nodeid, timeoutMs=None)
925991
ClusterCommand.TestOnlySendCommandTimedRequestFlagWithNoTimedInvoke(
926992
future, eventLoop, responseType, device.deviceProxy, ClusterCommand.CommandPath(
927993
EndpointId=endpoint,
@@ -953,7 +1019,7 @@ async def SendCommand(self, nodeid: int, endpoint: int, payload: ClusterObjects.
9531019
eventLoop = asyncio.get_running_loop()
9541020
future = eventLoop.create_future()
9551021

956-
device = self.GetConnectedDeviceSync(nodeid, timeoutMs=interactionTimeoutMs)
1022+
device = await self.GetConnectedDevice(nodeid, timeoutMs=interactionTimeoutMs)
9571023
ClusterCommand.SendCommand(
9581024
future, eventLoop, responseType, device.deviceProxy, ClusterCommand.CommandPath(
9591025
EndpointId=endpoint,
@@ -994,7 +1060,7 @@ async def SendBatchCommands(self, nodeid: int, commands: typing.List[ClusterComm
9941060
eventLoop = asyncio.get_running_loop()
9951061
future = eventLoop.create_future()
9961062

997-
device = self.GetConnectedDeviceSync(nodeid, timeoutMs=interactionTimeoutMs)
1063+
device = await self.GetConnectedDevice(nodeid, timeoutMs=interactionTimeoutMs)
9981064

9991065
ClusterCommand.SendBatchCommands(
10001066
future, eventLoop, device.deviceProxy, commands,
@@ -1044,7 +1110,7 @@ async def WriteAttribute(self, nodeid: int,
10441110
eventLoop = asyncio.get_running_loop()
10451111
future = eventLoop.create_future()
10461112

1047-
device = self.GetConnectedDeviceSync(nodeid, timeoutMs=interactionTimeoutMs)
1113+
device = await self.GetConnectedDevice(nodeid, timeoutMs=interactionTimeoutMs)
10481114

10491115
attrs = []
10501116
for v in attributes:
@@ -1272,7 +1338,7 @@ async def Read(self, nodeid: int, attributes: typing.List[typing.Union[
12721338
eventLoop = asyncio.get_running_loop()
12731339
future = eventLoop.create_future()
12741340

1275-
device = self.GetConnectedDeviceSync(nodeid)
1341+
device = await self.GetConnectedDevice(nodeid)
12761342
attributePaths = [self._parseAttributePathTuple(
12771343
v) for v in attributes] if attributes else None
12781344
clusterDataVersionFilters = [self._parseDataVersionFilterTuple(

0 commit comments

Comments
 (0)