Skip to content

Commit

Permalink
fix(multiprocess_predictor): dont block main process when running inf…
Browse files Browse the repository at this point in the history
…erence
  • Loading branch information
PaulHax committed Jan 17, 2025
1 parent e095bd0 commit 86f791e
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 9 deletions.
2 changes: 1 addition & 1 deletion src/nrtk_explorer/app/images/annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ async def get_annotations(
)

to_detect = {id: id_to_image[id] for id in misses}
predictions = predictor.infer(to_detect)
predictions = await predictor.infer(to_detect)
for id, annotations in predictions.items():
self.cache.add_item(
id, annotations, self.add_to_cache_callback, self.delete_from_cache_callback
Expand Down
24 changes: 16 additions & 8 deletions src/nrtk_explorer/library/multiprocess_predictor.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import multiprocessing
import asyncio
import signal
import threading
import logging
Expand All @@ -8,6 +9,7 @@


def _child_worker(request_queue, result_queue, model_name, force_cpu):
signal.signal(signal.SIGINT, signal.SIG_IGN) # Ignore Ctrl+C in child
logger = logging.getLogger(__name__)
predictor = Predictor(model_name=model_name, force_cpu=force_cpu)

Expand Down Expand Up @@ -105,32 +107,38 @@ def set_model(self, model_name, force_cpu=False):
)
return self._wait_for_response(req_id)

def infer(self, images):
async def infer(self, images):
if not images:
return {}
with self._lock:
req_id = str(uuid.uuid4())
new_req = {"command": "INFER", "req_id": req_id, "payload": {"images": images}}
self._request_queue.put(new_req)

resp = self._wait_for_response(req_id)
resp = await self.__wait_for_response(req_id)
return resp.get("result")

def reset(self):
with self._lock:
req_id = str(uuid.uuid4())
self._request_queue.put({"command": "RESET", "req_id": req_id})
return self._wait_for_response(req_id)
async def __wait_for_response(self, req_id):
return await asyncio.get_event_loop().run_in_executor(None, self._get_response, req_id, 40)

def _wait_for_response(self, req_id):
return self._get_response(req_id, 40)

def _get_response(self, req_id, timeout=40):
while True:
try:
r_id, data = self._result_queue.get(timeout=40)
r_id, data = self._result_queue.get(timeout=timeout)
except queue.Empty:
raise TimeoutError("No response from worker.")
if r_id == req_id:
return data

def reset(self):
with self._lock:
req_id = str(uuid.uuid4())
self._request_queue.put({"command": "RESET", "req_id": req_id})
return self._wait_for_response(req_id)

def shutdown(self):
with self._lock:
try:
Expand Down

0 comments on commit 86f791e

Please sign in to comment.