Skip to content

Commit

Permalink
Change websocket cleanup performance from O(n^2) to O(n).
Browse files Browse the repository at this point in the history
Move subscription tracking to a dict keyed by the underlying websocket object.

Changes the cleanup from two nested for loops (the 2nd being the
unsubscribe_from method), to a single loop with (the common case) deleting a
single key from a dict.
  • Loading branch information
marshallbrekka committed Jun 17, 2024
1 parent f6f7176 commit 1a1e58d
Showing 1 changed file with 26 additions and 13 deletions.
39 changes: 26 additions & 13 deletions services/ui_backend_service/api/ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ class Websocket(object):
Example event:
{"type": "UPDATE", "uuid": "myst3rySh4ck", "resource": "/runs", "data": {"foo": "bar"}}
'''
subscriptions: List[WSSubscription] = []
_subscriptions: Dict[web.WebSocketResponse, List[WSSubscription]] = collections.defaultdict(list)

def __init__(self, app, db, event_emitter=None, queue_ttl: int = WS_QUEUE_TTL_SECONDS, cache=None):
self.event_emitter = event_emitter or AsyncIOEventEmitter()
Expand Down Expand Up @@ -83,7 +83,7 @@ async def event_handler(self, operation: str, resources: List[str], data: Dict,
a dictionary of filters used in the query when fetching complete data.
"""
# Check if event needs to be broadcast (if anyone is subscribed to the resource)
if any(subscription.resource in resources for subscription in self.subscriptions):
if any(subscription.resource in resources for subscription in self.subscriptions()):
# load the data and postprocessor for broadcasting if table
# is provided (otherwise data has already been loaded in advance)
if table_name:
Expand All @@ -108,18 +108,28 @@ async def event_handler(self, operation: str, resources: List[str], data: Dict,
'resources': resources,
'data': _data
})
for subscription in self.subscriptions:
for subscription in self.subscriptions():
try:
if subscription.disconnected_ts and time.time() - subscription.disconnected_ts > WS_QUEUE_TTL_SECONDS:
await self.unsubscribe_from(subscription.ws, subscription.uuid)
# We can assume that all websockets (not just this UUID) are disconnected, don't filter by UUID as well.
await self.unsubscribe_from(subscription.ws)
else:
await self._event_subscription(subscription, operation, resources, _data)
except ConnectionResetError:
self.logger.debug("Trying to broadcast to a stale subscription. Unsubscribing")
await self.unsubscribe_from(subscription.ws, subscription.uuid)
await self.unsubscribe_from(subscription.ws)
except Exception:
self.logger.exception("Broadcasting to subscription failed")

def subscriptions(self):
# Grab all of the keys upfront and use that to iterate so that callers can
# safely modify the subscriptions dict while we are iterating through it.
# This is primarily useful when calling `unsubscribe_from` during the cleanup
# loop in the event handler.
for k in list(self._subscriptions.keys()):
for sub in self._subscriptions[k]:
yield sub

async def _event_subscription(self, subscription: WSSubscription, operation: str, resources: List[str], data: Dict):
for resource in resources:
if subscription.resource == resource:
Expand All @@ -142,7 +152,7 @@ async def subscribe_to(self, ws, uuid: str, resource: str, since: int):
subscription = WSSubscription(
ws=ws, fullpath=resource, resource=_resource, query=query, uuid=uuid,
filter=filter_fn, disconnected_ts=None)
self.subscriptions.append(subscription)
self._subscriptions[ws].append(subscription)

# Send previous events that client might have missed due to disconnection
if since:
Expand All @@ -154,22 +164,25 @@ async def subscribe_to(self, ws, uuid: str, resource: str, since: int):
)

async def unsubscribe_from(self, ws, uuid: str = None):
if ws not in self._subscriptions:
return
if uuid:
self.subscriptions = list(
filter(lambda s: uuid != s.uuid or ws != s.ws, self.subscriptions))
self._subscriptions[ws] = list(
filter(lambda s: uuid != s.uuid or ws != s.ws, self._subscriptions[ws]))
if len(self._subscriptions[ws]) == 0:
del self._subscriptions[ws]
else:
self.subscriptions = list(
filter(lambda s: ws != s.ws, self.subscriptions))
del self._subscriptions[ws]

async def handle_disconnect(self, ws):
"""
Sets disconnected timestamp on websocket subscription without removing it from the list.
Removing is handled by event_handler that checks for expired subscriptions before emitting
"""
self.subscriptions = list(
self._subscriptions[ws] = list(
map(
lambda sub: sub._replace(disconnected_ts=time.time()) if sub.ws == ws else sub,
self.subscriptions)
lambda sub: sub._replace(disconnected_ts=time.time()),
self._subscriptions[ws])
)

async def websocket_handler(self, request):
Expand Down

0 comments on commit 1a1e58d

Please sign in to comment.