Skip to content

Commit

Permalink
refactor: rename object_detector to predictor
Browse files Browse the repository at this point in the history
  • Loading branch information
PaulHax committed Jan 15, 2025
1 parent 36cb3de commit e224a44
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 17 deletions.
6 changes: 4 additions & 2 deletions src/nrtk_explorer/app/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@


class LazyDict(Mapping):
"""If function provided for value, run function when value is accessed"""

def __init__(self, *args, **kw):
self._raw_dict = dict(*args, **kw)

Expand Down Expand Up @@ -244,7 +246,7 @@ def delete_meta_state(old_ids, new_ids):
def on_server_ready(self, *args, **kwargs):
self.state.change("inference_model")(self.on_inference_model_change)
self.state.change("current_dataset")(self._cancel_update_images)
self.state.change("current_dataset")(self.reset_detector)
self.state.change("current_dataset")(self.reset_predictor)
self.state.change("confidence_score_threshold")(self._start_update_images)

def on_inference_model_change(self, **kwargs):
Expand All @@ -253,7 +255,7 @@ def on_inference_model_change(self, **kwargs):
self.predictor.set_model(self.state.inference_model)
self._start_update_images()

def reset_detector(self, **kwargs):
def reset_predictor(self, **kwargs):
self.predictor.reset()

def set_on_transform(self, fn):
Expand Down
10 changes: 5 additions & 5 deletions src/nrtk_explorer/library/multiprocess_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@
import logging
import queue
import uuid
from .object_detector import ObjectDetector
from .predictor import Predictor


def _child_worker(request_queue, result_queue, model_name, force_cpu):
logger = logging.getLogger(__name__)
detector = ObjectDetector(model_name=model_name, force_cpu=force_cpu)
predictor = Predictor(model_name=model_name, force_cpu=force_cpu)

while True:
try:
Expand All @@ -28,7 +28,7 @@ def _child_worker(request_queue, result_queue, model_name, force_cpu):

if command == "SET_MODEL":
try:
detector = ObjectDetector(
predictor = Predictor(
model_name=payload["model_name"],
force_cpu=payload["force_cpu"],
)
Expand All @@ -38,14 +38,14 @@ def _child_worker(request_queue, result_queue, model_name, force_cpu):
result_queue.put((req_id, {"status": "ERROR", "message": str(e)}))
elif command == "INFER":
try:
predictions = detector.eval(payload["images"])
predictions = predictor.eval(payload["images"])
result_queue.put((req_id, {"status": "OK", "result": predictions}))
except Exception as e:
logger.exception("Inference failed.")
result_queue.put((req_id, {"status": "ERROR", "message": str(e)}))
elif command == "RESET":
try:
detector.reset()
predictor.reset()
result_queue.put((req_id, {"status": "OK"}))
except Exception as e:
logger.exception("Reset failed.")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,7 @@ class ImageWithId(NamedTuple):
STARTING_BATCH_SIZE = 32


class ObjectDetector:
"""Object detection using Hugging Face's transformers library"""
class Predictor:

def __init__(
self,
Expand Down Expand Up @@ -106,7 +105,7 @@ def eval(
self.batch_size = self.batch_size // 2
self.batch_size = self.batch_size
print(
f"Caught out of memory exception:\n{e}\nWas batch_size={previous_batch_size}, setting batch_size={self.batch_size}"
f"Changing pipeline batch_size from {previous_batch_size} to {self.batch_size} because caught out of memory exception:\n{e}"
)
else:
raise
Expand Down
14 changes: 7 additions & 7 deletions tests/test_object_detector.py → tests/test_predictor.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
import pytest
from nrtk_explorer.library import object_detector
from nrtk_explorer.library.predictor import Predictor
from nrtk_explorer.library.multiprocess_predictor import MultiprocessPredictor
from nrtk_explorer.library.scoring import compute_score
from nrtk_explorer.library.dataset import get_dataset
from utils import get_images, DATASET
from nrtk_explorer.library.multiprocess_predictor import MultiprocessPredictor


def test_detector_small():
def test_predictor_small():
sample = get_images()
detector = object_detector.ObjectDetector(model_name="hustvl/yolos-tiny")
img = detector.eval(sample)
predictor = Predictor(model_name="hustvl/yolos-tiny")
img = predictor.eval(sample)
assert len(img) == len(sample.keys())


Expand Down Expand Up @@ -46,8 +46,8 @@ def test_set_model(predictor):
def test_scorer():
ds = get_dataset(DATASET)
sample = get_images()
detector = object_detector.ObjectDetector(model_name="facebook/detr-resnet-50")
predictions = detector.eval(sample)
predictor = Predictor(model_name="facebook/detr-resnet-50")
predictions = predictor.eval(sample)

dataset_annotations = dict()
for annotation in ds.anns.values():
Expand Down

0 comments on commit e224a44

Please sign in to comment.