perf(annotations): run model inference in subprocess
PaulHax committed Jan 15, 2025
1 parent 8f9f6e3 · commit 38fcae7
Showing 6 changed files with 197 additions and 23 deletions.
src/nrtk_explorer/app/core.py (2 changes: 1 addition & 1 deletion)
@@ -32,7 +32,7 @@
DEFAULT_DATASETS = [
f"{DIR_NAME}/coco-od-2017/test_val2017.json",
]
-NUM_IMAGES_DEFAULT = 200
+NUM_IMAGES_DEFAULT = 500
NUM_IMAGES_DEBOUNCE_TIME = 0.3 # seconds


src/nrtk_explorer/app/images/annotations.py (10 changes: 5 additions & 5 deletions)
@@ -2,7 +2,7 @@
from functools import lru_cache, partial
from PIL import Image
from nrtk_explorer.app.images.cache import LruCache
-from nrtk_explorer.library.object_detector import ObjectDetector
+from nrtk_explorer.library.multiprocess_predictor import MultiprocessPredictor
from nrtk_explorer.library.scoring import partition


@@ -67,15 +67,15 @@ def __init__(
self.add_to_cache_callback = add_to_cache_callback
self.delete_from_cache_callback = delete_from_cache_callback

-    def get_annotations(self, detector: ObjectDetector, id_to_image: Dict[str, Image.Image]):
+    async def get_annotations(
+        self, predictor: MultiprocessPredictor, id_to_image: Dict[str, Image.Image]
+    ):
hits, misses = partition(
lambda id: self.cache.get_item(id) is not None, id_to_image.keys()
)

to_detect = {id: id_to_image[id] for id in misses}
-        predictions = detector.eval(
-            to_detect,
-        )
+        predictions = predictor.infer(to_detect)
for id, annotations in predictions.items():
self.cache.add_item(
id, annotations, self.add_to_cache_callback, self.delete_from_cache_callback
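For context, partition (imported from nrtk_explorer.library.scoring) splits the ids by a predicate into cache hits and misses. A minimal stand-in consistent with the call site above (the real implementation may differ):

def partition(predicate, items):
    # Split items into (matches, non_matches) by predicate.
    hits, misses = [], []
    for item in items:
        (hits if predicate(item) else misses).append(item)
    return hits, misses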
src/nrtk_explorer/app/transforms.py (24 changes: 14 additions & 10 deletions)
@@ -11,7 +11,7 @@
import nrtk_explorer.library.transforms as trans
import nrtk_explorer.library.nrtk_transforms as nrtk_trans
import nrtk_explorer.library.yaml_transforms as nrtk_yaml
-from nrtk_explorer.library import object_detector
+from nrtk_explorer.library.multiprocess_predictor import MultiprocessPredictor
from nrtk_explorer.library.app_config import process_config
from nrtk_explorer.library.scoring import (
compute_score,
@@ -239,21 +239,23 @@ def delete_meta_state(old_ids, new_ids):

self.visible_dataset_ids = [] # set by ImageList via self.on_scroll callback

+        self.predictor = MultiprocessPredictor(model_name=self.state.inference_model)

def on_server_ready(self, *args, **kwargs):
self.state.change("inference_model")(self.on_inference_model_change)
self.on_inference_model_change()
self.state.change("current_dataset")(self._cancel_update_images)
self.state.change("current_dataset")(self.reset_detector)
self.state.change("confidence_score_threshold")(self._start_update_images)

def on_inference_model_change(self, **kwargs):
self.original_detection_annotations.cache_clear()
self.transformed_detection_annotations.cache_clear()
-        self.detector = object_detector.ObjectDetector(model_name=self.state.inference_model)
+        # self.detector = object_detector.ObjectDetector(model_name=self.state.inference_model)
+        self.predictor.set_model(self.state.inference_model)
self._start_update_images()

def reset_detector(self, **kwargs):
-        self.detector.reset()
+        self.predictor.reset()

def set_on_transform(self, fn):
self._on_transform_fn = fn
@@ -289,8 +291,8 @@ async def update_transformed_images(self, dataset_ids, visible=False):
)

with self.state:
-            annotations = self.transformed_detection_annotations.get_annotations(
-                self.detector, id_to_image
+            annotations = await self.transformed_detection_annotations.get_annotations(
+                self.predictor, id_to_image
)
await self.server.network_completion

@@ -328,7 +330,7 @@ async def update_transformed_images(self, dataset_ids, visible=False):
self.on_transform(id_to_image) # inform embeddings app
self.state.flush()

-    def compute_predictions_original_images(self, dataset_ids):
+    async def compute_predictions_original_images(self, dataset_ids):
if not self.state.predictions_original_images_enabled:
return

@@ -341,8 +343,10 @@ def compute_predictions_original_images(self, dataset_ids):
}
)

-        self.predictions_original_images = self.original_detection_annotations.get_annotations(
-            self.detector, image_id_to_image
+        self.predictions_original_images = (
+            await self.original_detection_annotations.get_annotations(
+                self.predictor, image_id_to_image
+            )
)

ground_truth_annotations = self.ground_truth_annotations.get_annotations(dataset_ids)
@@ -369,7 +373,7 @@ async def _update_images(self, dataset_ids, visible=False):

# always push to state because compute_predictions_original_images updates score metadata
with self.state:
-            self.compute_predictions_original_images(dataset_ids)
+            await self.compute_predictions_original_images(dataset_ids)
await self.server.network_completion
# sortable score value may have changed which may have changed images that are in view
self.server.controller.check_images_in_view()
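Net effect of the transforms.py changes: the app now holds one long-lived MultiprocessPredictor and hot-swaps models inside the worker instead of constructing a new ObjectDetector on every model change. An illustrative reduction of the wiring above (cache names abbreviated, not verbatim from the file):

predictor = MultiprocessPredictor(model_name=state.inference_model)

def on_inference_model_change(**kwargs):
    # annotations cached for the previous model are stale
    original_annotations.cache_clear()
    transformed_annotations.cache_clear()
    predictor.set_model(state.inference_model)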
src/nrtk_explorer/library/multiprocess_predictor.py (141 changes: 141 additions & 0 deletions)
@@ -0,0 +1,141 @@
import multiprocessing
import signal
import threading
import logging
import queue
import uuid
from .object_detector import ObjectDetector


def _child_worker(request_queue, result_queue, model_name, force_cpu):
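    # Child-process loop: owns the ObjectDetector and serves SET_MODEL /
    # INFER / RESET requests tagged with req_id; a None message is the exit signal.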
logger = logging.getLogger(__name__)
detector = ObjectDetector(model_name=model_name, force_cpu=force_cpu)

while True:
try:
msg = request_queue.get()
except (EOFError, KeyboardInterrupt):
logger.debug("Worker: Exiting on interrupt or queue EOF.")
break
if msg is None: # Exit signal
logger.debug("Worker: Received EXIT command. Shutting down.")
break

command = msg["command"]
req_id = msg["req_id"]
payload = msg.get("payload", {})
logger.debug(f"Worker: Received {command} with ID {req_id}")

if command == "SET_MODEL":
try:
detector = ObjectDetector(
model_name=payload["model_name"],
force_cpu=payload["force_cpu"],
)
result_queue.put((req_id, {"status": "OK"}))
except Exception as e:
logger.exception("Failed to set model.")
result_queue.put((req_id, {"status": "ERROR", "message": str(e)}))
elif command == "INFER":
try:
predictions = detector.eval(payload["images"])
result_queue.put((req_id, {"status": "OK", "result": predictions}))
except Exception as e:
logger.exception("Inference failed.")
result_queue.put((req_id, {"status": "ERROR", "message": str(e)}))
elif command == "RESET":
try:
detector.reset()
result_queue.put((req_id, {"status": "OK"}))
except Exception as e:
logger.exception("Reset failed.")
result_queue.put((req_id, {"status": "ERROR", "message": str(e)}))

logger.debug("Worker: shutting down.")


class MultiprocessPredictor:
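    # Blocking, thread-safe facade around the worker process: each public
    # method enqueues a request and waits for the response whose req_id
    # matches on the result queue.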
def __init__(self, model_name="facebook/detr-resnet-50", force_cpu=False):
self._lock = threading.Lock()
self.model_name = model_name
self.force_cpu = force_cpu
self._proc = None
self._request_queue = None
self._result_queue = None
self._start_process()

def handle_shutdown(signum, frame):
self.shutdown()

signal.signal(signal.SIGINT, handle_shutdown)

def _start_process(self):
with self._lock:
if self._proc is not None and self._proc.is_alive():
self.shutdown()
multiprocessing.set_start_method("spawn", force=True)
self._request_queue = multiprocessing.Queue()
self._result_queue = multiprocessing.Queue()
self._proc = multiprocessing.Process(
target=_child_worker,
args=(
self._request_queue,
self._result_queue,
self.model_name,
self.force_cpu,
),
daemon=True,
)
self._proc.start()

def set_model(self, model_name, force_cpu=False):
with self._lock:
self.model_name = model_name
self.force_cpu = force_cpu
req_id = str(uuid.uuid4())
self._request_queue.put(
{
"command": "SET_MODEL",
"req_id": req_id,
"payload": {
"model_name": self.model_name,
"force_cpu": self.force_cpu,
},
}
)
return self._wait_for_response(req_id)

def infer(self, images):
if not images:
return {}
with self._lock:
req_id = str(uuid.uuid4())
new_req = {"command": "INFER", "req_id": req_id, "payload": {"images": images}}
self._request_queue.put(new_req)

resp = self._wait_for_response(req_id)
return resp.get("result")

def reset(self):
with self._lock:
req_id = str(uuid.uuid4())
self._request_queue.put({"command": "RESET", "req_id": req_id})
return self._wait_for_response(req_id)

def _wait_for_response(self, req_id):
while True:
try:
r_id, data = self._result_queue.get(timeout=40)
except queue.Empty:
raise TimeoutError("No response from worker.")
if r_id == req_id:
return data

def shutdown(self):
with self._lock:
try:
self._request_queue.put(None)
except Exception:
logging.warning("Could not send exit message to worker.")
if self._proc:
self._proc.join()
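A minimal usage sketch for the new class (illustrative, not part of this diff; the image path is hypothetical):

from PIL import Image
from nrtk_explorer.library.multiprocess_predictor import MultiprocessPredictor

predictor = MultiprocessPredictor(model_name="facebook/detr-resnet-50")
try:
    images = {"img-1": Image.open("example.jpg")}  # any {id: PIL.Image} mapping
    predictions = predictor.infer(images)          # blocks until the worker replies
    predictor.set_model("hustvl/yolos-tiny")       # hot-swap the model in the child
    predictions = predictor.infer(images)
finally:
    predictor.shutdown()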
src/nrtk_explorer/library/object_detector.py (11 changes: 4 additions & 7 deletions)
@@ -50,12 +50,9 @@ def pipeline(self) -> transformers.pipeline:
@pipeline.setter
def pipeline(self, model_name: str):
"""Set the pipeline for object detection using Hugging Face's transformers library"""
-        if self.task is None:
-            self._pipeline = transformers.pipeline(model=model_name, device=self.device)
-        else:
-            self._pipeline = transformers.pipeline(
-                model=model_name, device=self.device, task=self.task
-            )
+        self._pipeline = transformers.pipeline(
+            model=model_name, device=self.device, task=self.task, use_fast=True
+        )
# Do not display warnings
transformers.utils.logging.set_verbosity_error()

@@ -109,7 +106,7 @@ def eval(
self.batch_size = self.batch_size // 2
self.batch_size = self.batch_size
print(
f"OOM (Pytorch exception {e}) due to batch_size={previous_batch_size}, setting batch_size={self.batch_size}"
f"Caught out of memory exception:\n{e}\nWas batch_size={previous_batch_size}, setting batch_size={self.batch_size}"
)
else:
raise
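For context, that print sits inside eval()'s out-of-memory backoff: on a CUDA OOM RuntimeError the batch size is halved and the batch retried. An assumed sketch of the control flow (the real method wraps a transformers pipeline call and maps predictions back to image ids):

def eval(self, images):
    # Assumed reconstruction of the retry loop around the pipeline call.
    while True:
        try:
            return self.pipeline(list(images.values()), batch_size=self.batch_size)
        except RuntimeError as e:
            if "out of memory" in str(e) and self.batch_size > 1:
                previous_batch_size = self.batch_size
                self.batch_size = self.batch_size // 2
                print(
                    f"Caught out of memory exception:\n{e}\n"
                    f"Was batch_size={previous_batch_size}, setting batch_size={self.batch_size}"
                )
            else:
                raise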
tests/test_object_detector.py (32 changes: 32 additions & 0 deletions)
@@ -1,7 +1,9 @@
import pytest
from nrtk_explorer.library import object_detector
from nrtk_explorer.library.scoring import compute_score
from nrtk_explorer.library.dataset import get_dataset
from utils import get_images, DATASET
from nrtk_explorer.library.multiprocess_predictor import MultiprocessPredictor


def test_detector_small():
@@ -11,6 +13,36 @@ def test_detector_small():
assert len(img) == len(sample.keys())


@pytest.fixture
def predictor():
predictor = MultiprocessPredictor(model_name="facebook/detr-resnet-50")
yield predictor
predictor.shutdown()


def test_detect(predictor):
"""Test the detect method with sample images."""
images = get_images()
results = predictor.infer(images)
assert len(results) == len(images), "Number of results should match number of images"
for img_id, preds in results.items():
assert isinstance(preds, list), f"Predictions for {img_id} should be a list"


def test_set_model(predictor):
"""Test setting a new model and performing detection."""
predictor.set_model(model_name="hustvl/yolos-tiny")
images = get_images()
results = predictor.infer(images)
assert len(results) == len(
images
), "Number of results should match number of images after setting new model"
for img_id, preds in results.items():
assert isinstance(
preds, list
), f"Predictions for {img_id} should be a list after setting new model"


def test_scorer():
ds = get_dataset(DATASET)
sample = get_images()
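The new predictor tests can be run directly with pytest (illustrative invocation):

pytest tests/test_object_detector.py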
