Skip to content

Commit

Permalink
feat(coco_utils): add scoring for classification model
Browse files Browse the repository at this point in the history
  • Loading branch information
PaulHax committed Nov 22, 2024
1 parent 84e339c commit 877484d
Showing 1 changed file with 35 additions and 0 deletions.
35 changes: 35 additions & 0 deletions src/nrtk_explorer/library/coco_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,18 @@ def keys_to_dataset_ids(image_dict):
return {image_id_to_dataset_id(key): value for key, value in image_dict.items()}


def make_cat_ids(dataset):
"""Get category ids from annotations."""
label_to_id = {cat["name"]: cat["id"] for cat in dataset.cats.values()}

def get_cat_id(annotation):
if "category_id" in annotation:
return annotation["category_id"]
return label_to_id.get(annotation["label"], None)

return get_cat_id


def compute_score(dataset, actual_info, predicted_info):
"""Compute score for image ids."""

Expand Down Expand Up @@ -153,6 +165,29 @@ def is_empty(prediction_pair):

actual, predicted, ids = zip(*has_annotations)

all_annotations_have_bbox = all(
"bbox" in annotation
for annotation_list in actual + predicted
for annotation in annotation_list
)

if not all_annotations_have_bbox:
# score with classification method
get_cat_id = make_cat_ids(dataset)
for pair in has_annotations:
actual, predicted, id = pair
actual_cat_ids = [get_cat_id(annotation) for annotation in actual]
predicted_cat_ids = [get_cat_id(annotation) for annotation in predicted]
matching_cat_ids = sum(
1
for cat_id in predicted_cat_ids
if cat_id in actual_cat_ids and cat_id is not None
)
total_cat_ids = len(set(actual_cat_ids + predicted_cat_ids))
score = matching_cat_ids / total_cat_ids if total_cat_ids > 0 else 0.0
scores.append((id, score))
return scores

if actual_info["type"] == "predictions":
actual_converted = convert_from_predictions_to_first_arg(
actual,
Expand Down

0 comments on commit 877484d

Please sign in to comment.