diff --git a/src/nrtk_explorer/app/transforms.py b/src/nrtk_explorer/app/transforms.py index e3b613a..fa5a266 100644 --- a/src/nrtk_explorer/app/transforms.py +++ b/src/nrtk_explorer/app/transforms.py @@ -52,6 +52,8 @@ class LazyDict(Mapping): + """If function provided for value, run function when value is accessed""" + def __init__(self, *args, **kw): self._raw_dict = dict(*args, **kw) @@ -244,7 +246,7 @@ def delete_meta_state(old_ids, new_ids): def on_server_ready(self, *args, **kwargs): self.state.change("inference_model")(self.on_inference_model_change) self.state.change("current_dataset")(self._cancel_update_images) - self.state.change("current_dataset")(self.reset_detector) + self.state.change("current_dataset")(self.reset_predictor) self.state.change("confidence_score_threshold")(self._start_update_images) def on_inference_model_change(self, **kwargs): @@ -253,7 +255,7 @@ def on_inference_model_change(self, **kwargs): self.predictor.set_model(self.state.inference_model) self._start_update_images() - def reset_detector(self, **kwargs): + def reset_predictor(self, **kwargs): self.predictor.reset() def set_on_transform(self, fn): @@ -403,6 +405,13 @@ def _cancel_update_images(self, **kwargs): self._update_task.cancel() def _start_update_images(self, **kwargs): + """ + After updating the images visible in the image list, all other selected + images are updated and their scores computed. After images are scored, + the table sort may have changed the images that are visible, so + ImageList is asked to send visible image IDs again, which may trigger + a new _update_all_images if the set of images in view has changed. + """ self._cancel_update_images() self._update_task = asynchronous.create_task( self._update_all_images(self.visible_dataset_ids) diff --git a/src/nrtk_explorer/library/multiprocess_predictor.py b/src/nrtk_explorer/library/multiprocess_predictor.py index 6343f4f..fd3a106 100644 --- a/src/nrtk_explorer/library/multiprocess_predictor.py +++ b/src/nrtk_explorer/library/multiprocess_predictor.py @@ -4,12 +4,12 @@ import logging import queue import uuid -from .object_detector import ObjectDetector +from .predictor import Predictor def _child_worker(request_queue, result_queue, model_name, force_cpu): logger = logging.getLogger(__name__) - detector = ObjectDetector(model_name=model_name, force_cpu=force_cpu) + predictor = Predictor(model_name=model_name, force_cpu=force_cpu) while True: try: @@ -28,7 +28,7 @@ def _child_worker(request_queue, result_queue, model_name, force_cpu): if command == "SET_MODEL": try: - detector = ObjectDetector( + predictor = Predictor( model_name=payload["model_name"], force_cpu=payload["force_cpu"], ) @@ -38,14 +38,14 @@ def _child_worker(request_queue, result_queue, model_name, force_cpu): result_queue.put((req_id, {"status": "ERROR", "message": str(e)})) elif command == "INFER": try: - predictions = detector.eval(payload["images"]) + predictions = predictor.eval(payload["images"]) result_queue.put((req_id, {"status": "OK", "result": predictions})) except Exception as e: logger.exception("Inference failed.") result_queue.put((req_id, {"status": "ERROR", "message": str(e)})) elif command == "RESET": try: - detector.reset() + predictor.reset() result_queue.put((req_id, {"status": "OK"})) except Exception as e: logger.exception("Reset failed.") diff --git a/src/nrtk_explorer/library/object_detector.py b/src/nrtk_explorer/library/predictor.py similarity index 94% rename from src/nrtk_explorer/library/object_detector.py rename to src/nrtk_explorer/library/predictor.py index 038737c..a600cd0 100644 --- a/src/nrtk_explorer/library/object_detector.py +++ b/src/nrtk_explorer/library/predictor.py @@ -17,8 +17,7 @@ class ImageWithId(NamedTuple): STARTING_BATCH_SIZE = 32 -class ObjectDetector: - """Object detection using Hugging Face's transformers library""" +class Predictor: def __init__( self, @@ -106,7 +105,7 @@ def eval( self.batch_size = self.batch_size // 2 self.batch_size = self.batch_size print( - f"Caught out of memory exception:\n{e}\nWas batch_size={previous_batch_size}, setting batch_size={self.batch_size}" + f"Changing pipeline batch_size from {previous_batch_size} to {self.batch_size} because caught out of memory exception:\n{e}" ) else: raise diff --git a/tests/test_object_detector.py b/tests/test_predictor.py similarity index 86% rename from tests/test_object_detector.py rename to tests/test_predictor.py index 8447d04..e434672 100644 --- a/tests/test_object_detector.py +++ b/tests/test_predictor.py @@ -1,15 +1,15 @@ import pytest -from nrtk_explorer.library import object_detector +from nrtk_explorer.library.predictor import Predictor +from nrtk_explorer.library.multiprocess_predictor import MultiprocessPredictor from nrtk_explorer.library.scoring import compute_score from nrtk_explorer.library.dataset import get_dataset from utils import get_images, DATASET -from nrtk_explorer.library.multiprocess_predictor import MultiprocessPredictor -def test_detector_small(): +def test_predictor_small(): sample = get_images() - detector = object_detector.ObjectDetector(model_name="hustvl/yolos-tiny") - img = detector.eval(sample) + predictor = Predictor(model_name="hustvl/yolos-tiny") + img = predictor.eval(sample) assert len(img) == len(sample.keys()) @@ -46,8 +46,8 @@ def test_set_model(predictor): def test_scorer(): ds = get_dataset(DATASET) sample = get_images() - detector = object_detector.ObjectDetector(model_name="facebook/detr-resnet-50") - predictions = detector.eval(sample) + predictor = Predictor(model_name="facebook/detr-resnet-50") + predictions = predictor.eval(sample) dataset_annotations = dict() for annotation in ds.anns.values():