From 6414b7d583affb7ed9b3a807329a08017e0bf8be Mon Sep 17 00:00:00 2001 From: Paul Elliott Date: Wed, 15 Jan 2025 14:06:53 -0500 Subject: [PATCH] refactor(transforms): pass original predictions via func args Rathern than through member variable --- src/nrtk_explorer/app/transforms.py | 27 ++++++++++++++++----------- 1 file changed, 16 insertions(+), 11 deletions(-) diff --git a/src/nrtk_explorer/app/transforms.py b/src/nrtk_explorer/app/transforms.py index fa5a266..fe85cf6 100644 --- a/src/nrtk_explorer/app/transforms.py +++ b/src/nrtk_explorer/app/transforms.py @@ -270,7 +270,9 @@ def on_apply_transform(self, **kwargs): self.state.transform_enabled_switch = True self._start_update_images() - async def update_transformed_images(self, dataset_ids, visible=False): + async def update_transformed_images( + self, dataset_ids, predictions_original_images, visible=False + ): if not self.state.transform_enabled: return @@ -310,10 +312,10 @@ async def update_transformed_images(self, dataset_ids, visible=False): ) # depends on original images predictions - if self.state.predictions_original_images_enabled: + if predictions_original_images: scores = compute_score( self.context.dataset, - self.predictions_original_images, + predictions_original_images, annotations, self.state.confidence_score_threshold, ) @@ -344,10 +346,8 @@ async def compute_predictions_original_images(self, dataset_ids): } ) - self.predictions_original_images = ( - await self.original_detection_annotations.get_annotations( - self.predictor, image_id_to_image - ) + predictions_original_images = await self.original_detection_annotations.get_annotations( + self.predictor, image_id_to_image ) ground_truth_annotations = self.ground_truth_annotations.get_annotations(dataset_ids) @@ -355,7 +355,7 @@ async def compute_predictions_original_images(self, dataset_ids): scores = compute_score( self.context.dataset, ground_truth_annotations, - self.predictions_original_images, + predictions_original_images, self.state.confidence_score_threshold, ) for dataset_id, score in scores: @@ -363,6 +363,8 @@ async def compute_predictions_original_images(self, dataset_ids): self.state, dataset_id, {"original_ground_to_original_detection_score": score} ) + return predictions_original_images + async def _update_images(self, dataset_ids, visible=False): if visible: # load images on state for ImageList @@ -374,16 +376,19 @@ async def _update_images(self, dataset_ids, visible=False): # always push to state because compute_predictions_original_images updates score metadata with self.state: - await self.compute_predictions_original_images(dataset_ids) + predictions_original_images = await self.compute_predictions_original_images( + dataset_ids + ) await self.server.network_completion # sortable score value may have changed which may have changed images that are in view self.server.controller.check_images_in_view() - await self.update_transformed_images(dataset_ids, visible=visible) + await self.update_transformed_images( + dataset_ids, predictions_original_images, visible=visible + ) async def _chunk_update_images(self, dataset_ids, visible=False): ids = list(dataset_ids) - for i in range(0, len(ids), UPDATE_IMAGES_CHUNK_SIZE): chunk = ids[i : i + UPDATE_IMAGES_CHUNK_SIZE] await self._update_images(chunk, visible=visible)