Skip to content

Commit

Permalink
refactor(object_detector): eval returns image_id keyed dict
Browse files Browse the repository at this point in the history
  • Loading branch information
PaulHax committed Jun 20, 2024
1 parent 573b474 commit da4d9f9
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 33 deletions.
5 changes: 3 additions & 2 deletions src/nrtk_explorer/app/image_ids.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@
DatasetId = str
SourceImageId = str
TransformedImageId = str
ImageId = Union[SourceImageId, TransformedImageId]
ResultId = str


def image_id_to_dataset_id(image_id: Union[SourceImageId, TransformedImageId]) -> DatasetId:
def image_id_to_dataset_id(image_id: ImageId) -> DatasetId:
return image_id.split("_")[-1]


def image_id_to_result_id(image_id: Union[SourceImageId, TransformedImageId]) -> ResultId:
def image_id_to_result_id(image_id: ImageId) -> ResultId:
return f"result_{image_id}"
2 changes: 1 addition & 1 deletion src/nrtk_explorer/app/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ def compute_annotations(self, ids):

predictions = self.detector.eval(image_ids=ids, content=self.context.image_objects)

for id_, annotations in predictions:
for id_, annotations in predictions.items():
image_annotations = []
for prediction in annotations:
category_id = None
Expand Down
4 changes: 2 additions & 2 deletions src/nrtk_explorer/library/coco_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,9 @@ def convert_from_predictions_to_first_arg(predictions, dataset, ids):
def convert_from_predictions_to_second_arg(predictions):
"""Convert predictions to COCOScorer format"""
annotations_predictions = list()
for img_predictions in predictions:
for img_predictions in predictions.values():
current_annotations = list()
for prediction in img_predictions[1]:
for prediction in img_predictions:
if prediction:
current_annotations.append(
(
Expand Down
44 changes: 16 additions & 28 deletions src/nrtk_explorer/library/object_detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,6 @@

from nrtk_explorer.library import images_manager

Annotation = dict # in COCO format
Annotations = list[Annotation]
ImageId = str
AnnotatedImage = tuple[ImageId, Annotations]
AnnotatedImages = list[AnnotatedImage]


class ObjectDetector:
"""Object detection using Hugging Face's transformers library"""
Expand Down Expand Up @@ -66,40 +60,34 @@ def eval(
image_ids: list[str],
content: Optional[dict] = None,
batch_size: int = 32,
) -> AnnotatedImages:
):
"""Compute object recognition. Returns Annotations grouped by input image paths."""
images: dict = {}

# Some models require all the images in a batch to be the same size,
# otherwise crash or UB.
batches: dict = {}
for path in image_ids:
img = None
if content and path in content:
img = content[path]
else:
img = self.manager.load_image(path)

images.setdefault(img.size, [[], []])
images[img.size][0].append(path)
images[img.size][1].append(img)
batches.setdefault(img.size, [[], []])
batches[img.size][0].append(path)
batches[img.size][1].append(img)

# Call by each group
predictions = [
list(
zip(
group[0],
self.pipeline(group[1], batch_size=batch_size),
)
predictions_in_baches = [
zip(
image_ids,
self.pipeline(images, batch_size=batch_size),
)
for group in images.values()
for image_ids, images in batches.values()
]
# Flatten the list of predictions
predictions = reduce(operator.iadd, predictions, [])

# order output by paths order
find_prediction = lambda id: next(
prediction for prediction in predictions if prediction[0] == id
)
output = [find_prediction(id) for id in image_ids]
# mypy wrongly thinks output's type is list[list[tuple[str, dict]]]
return output # type: ignore
predictions_by_image_id = {
image_id: predictions
for batch in predictions_in_baches
for image_id, predictions in batch
}
return predictions_by_image_id

0 comments on commit da4d9f9

Please sign in to comment.