diff --git a/src/nrtk_explorer/app/core.py b/src/nrtk_explorer/app/core.py index fc16642..00c8c6f 100644 --- a/src/nrtk_explorer/app/core.py +++ b/src/nrtk_explorer/app/core.py @@ -32,7 +32,7 @@ DEFAULT_DATASETS = [ f"{DIR_NAME}/coco-od-2017/test_val2017.json", ] -NUM_IMAGES_DEFAULT = 200 +NUM_IMAGES_DEFAULT = 500 NUM_IMAGES_DEBOUNCE_TIME = 0.3 # seconds diff --git a/src/nrtk_explorer/app/images/annotations.py b/src/nrtk_explorer/app/images/annotations.py index e3acef7..dcf9ce3 100644 --- a/src/nrtk_explorer/app/images/annotations.py +++ b/src/nrtk_explorer/app/images/annotations.py @@ -2,7 +2,7 @@ from functools import lru_cache, partial from PIL import Image from nrtk_explorer.app.images.cache import LruCache -from nrtk_explorer.library.object_detector import ObjectDetector +from nrtk_explorer.library.multiprocess_predictor import MultiprocessPredictor from nrtk_explorer.library.scoring import partition @@ -67,15 +67,15 @@ def __init__( self.add_to_cache_callback = add_to_cache_callback self.delete_from_cache_callback = delete_from_cache_callback - def get_annotations(self, detector: ObjectDetector, id_to_image: Dict[str, Image.Image]): + async def get_annotations( + self, predictor: MultiprocessPredictor, id_to_image: Dict[str, Image.Image] + ): hits, misses = partition( lambda id: self.cache.get_item(id) is not None, id_to_image.keys() ) to_detect = {id: id_to_image[id] for id in misses} - predictions = detector.eval( - to_detect, - ) + predictions = predictor.infer(to_detect) for id, annotations in predictions.items(): self.cache.add_item( id, annotations, self.add_to_cache_callback, self.delete_from_cache_callback diff --git a/src/nrtk_explorer/app/transforms.py b/src/nrtk_explorer/app/transforms.py index ee21330..e3b613a 100644 --- a/src/nrtk_explorer/app/transforms.py +++ b/src/nrtk_explorer/app/transforms.py @@ -11,7 +11,7 @@ import nrtk_explorer.library.transforms as trans import nrtk_explorer.library.nrtk_transforms as nrtk_trans import nrtk_explorer.library.yaml_transforms as nrtk_yaml -from nrtk_explorer.library import object_detector +from nrtk_explorer.library.multiprocess_predictor import MultiprocessPredictor from nrtk_explorer.library.app_config import process_config from nrtk_explorer.library.scoring import ( compute_score, @@ -239,9 +239,10 @@ def delete_meta_state(old_ids, new_ids): self.visible_dataset_ids = [] # set by ImageList via self.on_scroll callback + self.predictor = MultiprocessPredictor(model_name=self.state.inference_model) + def on_server_ready(self, *args, **kwargs): self.state.change("inference_model")(self.on_inference_model_change) - self.on_inference_model_change() self.state.change("current_dataset")(self._cancel_update_images) self.state.change("current_dataset")(self.reset_detector) self.state.change("confidence_score_threshold")(self._start_update_images) @@ -249,11 +250,11 @@ def on_server_ready(self, *args, **kwargs): def on_inference_model_change(self, **kwargs): self.original_detection_annotations.cache_clear() self.transformed_detection_annotations.cache_clear() - self.detector = object_detector.ObjectDetector(model_name=self.state.inference_model) + self.predictor.set_model(self.state.inference_model) self._start_update_images() def reset_detector(self, **kwargs): - self.detector.reset() + self.predictor.reset() def set_on_transform(self, fn): self._on_transform_fn = fn @@ -289,8 +290,8 @@ async def update_transformed_images(self, dataset_ids, visible=False): ) with self.state: - annotations = self.transformed_detection_annotations.get_annotations( - self.detector, id_to_image + annotations = await self.transformed_detection_annotations.get_annotations( + self.predictor, id_to_image ) await self.server.network_completion @@ -328,7 +329,7 @@ async def update_transformed_images(self, dataset_ids, visible=False): self.on_transform(id_to_image) # inform embeddings app self.state.flush() - def compute_predictions_original_images(self, dataset_ids): + async def compute_predictions_original_images(self, dataset_ids): if not self.state.predictions_original_images_enabled: return @@ -341,8 +342,10 @@ def compute_predictions_original_images(self, dataset_ids): } ) - self.predictions_original_images = self.original_detection_annotations.get_annotations( - self.detector, image_id_to_image + self.predictions_original_images = ( + await self.original_detection_annotations.get_annotations( + self.predictor, image_id_to_image + ) ) ground_truth_annotations = self.ground_truth_annotations.get_annotations(dataset_ids) @@ -369,7 +372,7 @@ async def _update_images(self, dataset_ids, visible=False): # always push to state because compute_predictions_original_images updates score metadata with self.state: - self.compute_predictions_original_images(dataset_ids) + await self.compute_predictions_original_images(dataset_ids) await self.server.network_completion # sortable score value may have changed which may have changed images that are in view self.server.controller.check_images_in_view() diff --git a/src/nrtk_explorer/library/multiprocess_predictor.py b/src/nrtk_explorer/library/multiprocess_predictor.py new file mode 100644 index 0000000..6343f4f --- /dev/null +++ b/src/nrtk_explorer/library/multiprocess_predictor.py @@ -0,0 +1,141 @@ +import multiprocessing +import signal +import threading +import logging +import queue +import uuid +from .object_detector import ObjectDetector + + +def _child_worker(request_queue, result_queue, model_name, force_cpu): + logger = logging.getLogger(__name__) + detector = ObjectDetector(model_name=model_name, force_cpu=force_cpu) + + while True: + try: + msg = request_queue.get() + except (EOFError, KeyboardInterrupt): + logger.debug("Worker: Exiting on interrupt or queue EOF.") + break + if msg is None: # Exit signal + logger.debug("Worker: Received EXIT command. Shutting down.") + break + + command = msg["command"] + req_id = msg["req_id"] + payload = msg.get("payload", {}) + logger.debug(f"Worker: Received {command} with ID {req_id}") + + if command == "SET_MODEL": + try: + detector = ObjectDetector( + model_name=payload["model_name"], + force_cpu=payload["force_cpu"], + ) + result_queue.put((req_id, {"status": "OK"})) + except Exception as e: + logger.exception("Failed to set model.") + result_queue.put((req_id, {"status": "ERROR", "message": str(e)})) + elif command == "INFER": + try: + predictions = detector.eval(payload["images"]) + result_queue.put((req_id, {"status": "OK", "result": predictions})) + except Exception as e: + logger.exception("Inference failed.") + result_queue.put((req_id, {"status": "ERROR", "message": str(e)})) + elif command == "RESET": + try: + detector.reset() + result_queue.put((req_id, {"status": "OK"})) + except Exception as e: + logger.exception("Reset failed.") + result_queue.put((req_id, {"status": "ERROR", "message": str(e)})) + + logger.debug("Worker: shutting down.") + + +class MultiprocessPredictor: + def __init__(self, model_name="facebook/detr-resnet-50", force_cpu=False): + self._lock = threading.Lock() + self.model_name = model_name + self.force_cpu = force_cpu + self._proc = None + self._request_queue = None + self._result_queue = None + self._start_process() + + def handle_shutdown(signum, frame): + self.shutdown() + + signal.signal(signal.SIGINT, handle_shutdown) + + def _start_process(self): + with self._lock: + if self._proc is not None and self._proc.is_alive(): + self.shutdown() + multiprocessing.set_start_method("spawn", force=True) + self._request_queue = multiprocessing.Queue() + self._result_queue = multiprocessing.Queue() + self._proc = multiprocessing.Process( + target=_child_worker, + args=( + self._request_queue, + self._result_queue, + self.model_name, + self.force_cpu, + ), + daemon=True, + ) + self._proc.start() + + def set_model(self, model_name, force_cpu=False): + with self._lock: + self.model_name = model_name + self.force_cpu = force_cpu + req_id = str(uuid.uuid4()) + self._request_queue.put( + { + "command": "SET_MODEL", + "req_id": req_id, + "payload": { + "model_name": self.model_name, + "force_cpu": self.force_cpu, + }, + } + ) + return self._wait_for_response(req_id) + + def infer(self, images): + if not images: + return {} + with self._lock: + req_id = str(uuid.uuid4()) + new_req = {"command": "INFER", "req_id": req_id, "payload": {"images": images}} + self._request_queue.put(new_req) + + resp = self._wait_for_response(req_id) + return resp.get("result") + + def reset(self): + with self._lock: + req_id = str(uuid.uuid4()) + self._request_queue.put({"command": "RESET", "req_id": req_id}) + return self._wait_for_response(req_id) + + def _wait_for_response(self, req_id): + while True: + try: + r_id, data = self._result_queue.get(timeout=40) + except queue.Empty: + raise TimeoutError("No response from worker.") + if r_id == req_id: + return data + + def shutdown(self): + with self._lock: + try: + self._request_queue.put(None) + except Exception: + logging.warning("Could not send exit message to worker.") + if self._proc: + self._proc.join() diff --git a/src/nrtk_explorer/library/object_detector.py b/src/nrtk_explorer/library/object_detector.py index 56d9130..038737c 100644 --- a/src/nrtk_explorer/library/object_detector.py +++ b/src/nrtk_explorer/library/object_detector.py @@ -50,12 +50,9 @@ def pipeline(self) -> transformers.pipeline: @pipeline.setter def pipeline(self, model_name: str): """Set the pipeline for object detection using Hugging Face's transformers library""" - if self.task is None: - self._pipeline = transformers.pipeline(model=model_name, device=self.device) - else: - self._pipeline = transformers.pipeline( - model=model_name, device=self.device, task=self.task - ) + self._pipeline = transformers.pipeline( + model=model_name, device=self.device, task=self.task, use_fast=True + ) # Do not display warnings transformers.utils.logging.set_verbosity_error() @@ -109,7 +106,7 @@ def eval( self.batch_size = self.batch_size // 2 self.batch_size = self.batch_size print( - f"OOM (Pytorch exception {e}) due to batch_size={previous_batch_size}, setting batch_size={self.batch_size}" + f"Caught out of memory exception:\n{e}\nWas batch_size={previous_batch_size}, setting batch_size={self.batch_size}" ) else: raise diff --git a/tests/test_object_detector.py b/tests/test_object_detector.py index 765b2c5..8447d04 100644 --- a/tests/test_object_detector.py +++ b/tests/test_object_detector.py @@ -1,7 +1,9 @@ +import pytest from nrtk_explorer.library import object_detector from nrtk_explorer.library.scoring import compute_score from nrtk_explorer.library.dataset import get_dataset from utils import get_images, DATASET +from nrtk_explorer.library.multiprocess_predictor import MultiprocessPredictor def test_detector_small(): @@ -11,6 +13,36 @@ def test_detector_small(): assert len(img) == len(sample.keys()) +@pytest.fixture +def predictor(): + predictor = MultiprocessPredictor(model_name="facebook/detr-resnet-50") + yield predictor + predictor.shutdown() + + +def test_detect(predictor): + """Test the detect method with sample images.""" + images = get_images() + results = predictor.infer(images) + assert len(results) == len(images), "Number of results should match number of images" + for img_id, preds in results.items(): + assert isinstance(preds, list), f"Predictions for {img_id} should be a list" + + +def test_set_model(predictor): + """Test setting a new model and performing detection.""" + predictor.set_model(model_name="hustvl/yolos-tiny") + images = get_images() + results = predictor.infer(images) + assert len(results) == len( + images + ), "Number of results should match number of images after setting new model" + for img_id, preds in results.items(): + assert isinstance( + preds, list + ), f"Predictions for {img_id} should be a list after setting new model" + + def test_scorer(): ds = get_dataset(DATASET) sample = get_images()