Skip to content

Commit

Permalink
refactor: rename object_detector to predictor
Browse files Browse the repository at this point in the history
  • Loading branch information
PaulHax committed Jan 15, 2025
1 parent 36cb3de commit e326e4e
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 17 deletions.
13 changes: 11 additions & 2 deletions src/nrtk_explorer/app/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@


class LazyDict(Mapping):
"""If function provided for value, run function when value is accessed"""

def __init__(self, *args, **kw):
self._raw_dict = dict(*args, **kw)

Expand Down Expand Up @@ -244,7 +246,7 @@ def delete_meta_state(old_ids, new_ids):
def on_server_ready(self, *args, **kwargs):
self.state.change("inference_model")(self.on_inference_model_change)
self.state.change("current_dataset")(self._cancel_update_images)
self.state.change("current_dataset")(self.reset_detector)
self.state.change("current_dataset")(self.reset_predictor)
self.state.change("confidence_score_threshold")(self._start_update_images)

def on_inference_model_change(self, **kwargs):
Expand All @@ -253,7 +255,7 @@ def on_inference_model_change(self, **kwargs):
self.predictor.set_model(self.state.inference_model)
self._start_update_images()

def reset_detector(self, **kwargs):
def reset_predictor(self, **kwargs):
self.predictor.reset()

def set_on_transform(self, fn):
Expand Down Expand Up @@ -403,6 +405,13 @@ def _cancel_update_images(self, **kwargs):
self._update_task.cancel()

def _start_update_images(self, **kwargs):
"""
After updating the images visible in the image list, all other selected
images are updated and their scores computed. After images are scored,
the table sort may have changed the images that are visible, so
ImageList is asked to send visible image IDs again, which may trigger
a new _update_all_images if the set of images in view has changed.
"""
self._cancel_update_images()
self._update_task = asynchronous.create_task(
self._update_all_images(self.visible_dataset_ids)
Expand Down
10 changes: 5 additions & 5 deletions src/nrtk_explorer/library/multiprocess_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@
import logging
import queue
import uuid
from .object_detector import ObjectDetector
from .predictor import Predictor


def _child_worker(request_queue, result_queue, model_name, force_cpu):
logger = logging.getLogger(__name__)
detector = ObjectDetector(model_name=model_name, force_cpu=force_cpu)
predictor = Predictor(model_name=model_name, force_cpu=force_cpu)

while True:
try:
Expand All @@ -28,7 +28,7 @@ def _child_worker(request_queue, result_queue, model_name, force_cpu):

if command == "SET_MODEL":
try:
detector = ObjectDetector(
predictor = Predictor(
model_name=payload["model_name"],
force_cpu=payload["force_cpu"],
)
Expand All @@ -38,14 +38,14 @@ def _child_worker(request_queue, result_queue, model_name, force_cpu):
result_queue.put((req_id, {"status": "ERROR", "message": str(e)}))
elif command == "INFER":
try:
predictions = detector.eval(payload["images"])
predictions = predictor.eval(payload["images"])
result_queue.put((req_id, {"status": "OK", "result": predictions}))
except Exception as e:
logger.exception("Inference failed.")
result_queue.put((req_id, {"status": "ERROR", "message": str(e)}))
elif command == "RESET":
try:
detector.reset()
predictor.reset()
result_queue.put((req_id, {"status": "OK"}))
except Exception as e:
logger.exception("Reset failed.")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,7 @@ class ImageWithId(NamedTuple):
STARTING_BATCH_SIZE = 32


class ObjectDetector:
"""Object detection using Hugging Face's transformers library"""
class Predictor:

def __init__(
self,
Expand Down Expand Up @@ -106,7 +105,7 @@ def eval(
self.batch_size = self.batch_size // 2
self.batch_size = self.batch_size
print(
f"Caught out of memory exception:\n{e}\nWas batch_size={previous_batch_size}, setting batch_size={self.batch_size}"
f"Changing pipeline batch_size from {previous_batch_size} to {self.batch_size} because caught out of memory exception:\n{e}"
)
else:
raise
Expand Down
14 changes: 7 additions & 7 deletions tests/test_object_detector.py → tests/test_predictor.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
import pytest
from nrtk_explorer.library import object_detector
from nrtk_explorer.library.predictor import Predictor
from nrtk_explorer.library.multiprocess_predictor import MultiprocessPredictor
from nrtk_explorer.library.scoring import compute_score
from nrtk_explorer.library.dataset import get_dataset
from utils import get_images, DATASET
from nrtk_explorer.library.multiprocess_predictor import MultiprocessPredictor


def test_detector_small():
def test_predictor_small():
sample = get_images()
detector = object_detector.ObjectDetector(model_name="hustvl/yolos-tiny")
img = detector.eval(sample)
predictor = Predictor(model_name="hustvl/yolos-tiny")
img = predictor.eval(sample)
assert len(img) == len(sample.keys())


Expand Down Expand Up @@ -46,8 +46,8 @@ def test_set_model(predictor):
def test_scorer():
ds = get_dataset(DATASET)
sample = get_images()
detector = object_detector.ObjectDetector(model_name="facebook/detr-resnet-50")
predictions = detector.eval(sample)
predictor = Predictor(model_name="facebook/detr-resnet-50")
predictions = predictor.eval(sample)

dataset_annotations = dict()
for annotation in ds.anns.values():
Expand Down

0 comments on commit e326e4e

Please sign in to comment.