diff --git a/src/nrtk_explorer/app/embeddings.py b/src/nrtk_explorer/app/embeddings.py
index b481075..839806e 100644
--- a/src/nrtk_explorer/app/embeddings.py
+++ b/src/nrtk_explorer/app/embeddings.py
@@ -1,8 +1,10 @@
+from typing import Dict
+from trame.decorators import TrameApp, change
+from PIL import Image
 from nrtk_explorer.widgets.nrtk_explorer import ScatterPlot
 from nrtk_explorer.library import embeddings_extractor
 from nrtk_explorer.library import dimension_reducers
 from nrtk_explorer.library.dataset import get_dataset
-from nrtk_explorer.library.scoring import partition
 from nrtk_explorer.app.applet import Applet
 
 from nrtk_explorer.app.images.image_ids import (
@@ -20,6 +22,41 @@
 from trame.app import get_server, asynchronous
 
 
+IdToImage = Dict[str, Image.Image]
+
+
+@TrameApp()
+class TransformedImages:
+    def __init__(self, server):
+        self.server = server
+        self.transformed_images: IdToImage = {}
+
+    def emit_update(self):
+        self.server.controller.update_transformed_images(self.transformed_images)
+
+    def add_images(self, dataset_id_to_image: IdToImage):
+        self.transformed_images.update(dataset_id_to_image)
+        self.emit_update()
+
+    @change("dataset_ids")
+    def on_dataset_ids(self, **kwargs):
+        self.transformed_images = {
+            k: v
+            for k, v in self.transformed_images.items()
+            if image_id_to_dataset_id(k) in self.server.state.dataset_ids
+        }
+        self.emit_update()
+
+    @change("current_dataset")
+    def on_dataset(self, **kwargs):
+        self.transformed_images = {}
+        self.emit_update()
+
+    def clear(self, **kwargs):
+        self.transformed_images = {}
+        self.emit_update()
+
+
 class EmbeddingsApp(Applet):
     def __init__(
         self,
@@ -53,15 +90,21 @@ def __init__(
             "id": "",
             "is_transformed": True,
         }
+        self.state.dimensionality = "3"
 
-    def on_server_ready(self, *args, **kwargs):
+        self.clear_points_transformations()
         # init vars
         self.on_feature_extraction_model_change()
-        self.state.change("feature_extraction_model")(self.on_feature_extraction_model_change)
+        self.transformed_images = TransformedImages(server)
+        self.server.controller.update_transformed_images.add(self.update_transformed_images)
+
+    def on_server_ready(self, *args, **kwargs):
+        self.state.change("feature_extraction_model")(self.on_feature_extraction_model_change)
+        self.save_embedding_params()
         self.update_points()
         self.state.change("dataset_ids")(self.update_points)
-        self.server.controller.apply_transform.add(self.clear_points_transformations)
+        self.server.controller.apply_transform.add(self.transformed_images.clear)
         self.state.change("transform_enabled_switch")(self.update_points_transformations_state)
 
     def on_feature_extraction_model_change(self, **kwargs):
@@ -75,29 +118,31 @@ def compute_points(self, fit_features, features):
             # reduce will fail if no features
             return []
 
-        if self.state.tab == "PCA":
+        params = self.embedding_params
+
+        if params["tab"] == "PCA":
             return self.reducer.reduce(
                 name="PCA",
                 fit_features=fit_features,
                 features=features,
-                dims=self.state.dimensionality,
-                whiten=self.state.pca_whiten,
-                solver=self.state.pca_solver,
+                dims=params["dimensionality"],
+                whiten=params["pca_whiten"],
+                solver=params["pca_solver"],
             )
 
         # must be UMAP
         args = {}
-        if self.state.umap_random_seed:
-            args["random_state"] = int(self.state.umap_random_seed_value)
+        if params["umap_random_seed"]:
+            args["random_state"] = int(params["umap_random_seed_value"])
 
-        if self.state.umap_n_neighbors:
-            args["n_neighbors"] = int(self.state.umap_n_neighbors_number)
+        if params["umap_n_neighbors"]:
+            args["n_neighbors"] = int(params["umap_n_neighbors_number"])
 
         return self.reducer.reduce(
             name="UMAP",
             fit_features=fit_features,
             features=features,
-            dims=self.state.dimensionality,
+            dims=params["dimensionality"],
             **args,
         )
 
@@ -111,14 +156,7 @@ def update_points_transformations_state(self, **kwargs):
         else:
             self.state.points_transformations = {}
 
-    async def compute_source_points(self):
-        with self.state:
-            self.state.is_loading = True
-            self.clear_points_transformations()
-
-        # Don't lock server before enabling the spinner on client
-        await self.server.network_completion
-
+    def compute_source_points(self):
         images = [
             self.images.get_image_without_cache_eviction(id) for id in self.state.dataset_ids
         ]
@@ -135,36 +173,65 @@ async def compute_source_points(self):
 
             self.state.camera_position = []
 
+    async def _update_points(self):
+        with self.state:
+            self.state.is_loading = True
+            self.points_sources = {}
+            self.clear_points_transformations()
+        # Don't lock server before enabling the spinner on client
+        await self.server.network_completion
+
+        self.save_embedding_params()
+        with self.state:
+            self.compute_source_points()
+            self.update_transformed_images(self.transformed_images.transformed_images)
             self.state.is_loading = False
 
     def update_points(self, **kwargs):
         if hasattr(self, "_update_task"):
             self._update_task.cancel()
-        self._update_task = asynchronous.create_task(self.compute_source_points())
+        self._update_task = asynchronous.create_task(self._update_points())
+
+    def save_embedding_params(self):
+        self.embedding_params = {
+            "tab": self.state.tab,
+            "dimensionality": self.state.dimensionality,
+            "pca_whiten": self.state.pca_whiten,
+            "pca_solver": self.state.pca_solver,
+            "umap_random_seed": self.state.umap_random_seed,
+            "umap_random_seed_value": self.state.umap_random_seed_value,
+            "umap_n_neighbors": self.state.umap_n_neighbors,
+            "umap_n_neighbors_number": self.state.umap_n_neighbors_number,
+        }
 
     def on_run_clicked(self):
+        self.save_embedding_params()
         self.update_points()
 
     def on_run_transformations(self, id_to_image):
-        hits, misses = partition(
-            lambda id: image_id_to_dataset_id(id) in self._stashed_points_transformations,
-            id_to_image.keys(),
-        )
+        self.transformed_images.add_images(id_to_image)
+
+    def update_transformed_images(self, id_to_image):
+        new_to_plot = {
+            id: img
+            for id, img in id_to_image.items()
+            if image_id_to_dataset_id(id) not in self._stashed_points_transformations
+        }
 
-        to_plot = {id: id_to_image[id] for id in misses}
         transformation_features = self.extractor.extract(
-            list(to_plot.values()),
+            list(new_to_plot.values()),
             batch_size=int(self.state.model_batch_size),
        )
 
         points = self.compute_points(self.features, transformation_features)
-        ids_to_points = zip(to_plot.keys(), points)
+        image_id_to_point = zip(new_to_plot.keys(), points)
 
-        updated_points = {image_id_to_dataset_id(id): point for id, point in ids_to_points}
+        updated_points = {image_id_to_dataset_id(id): point for id, point in image_id_to_point}
         self._stashed_points_transformations = {
             **self._stashed_points_transformations,
             **updated_points,
         }
+        self.update_points_transformations_state()  # called by category filter
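
Note on the pattern this diff introduces: compute_points now reads from self.embedding_params, a snapshot that save_embedding_params copies out of the live trame state when "Run" is clicked or a recompute starts. The sketch below is a minimal, trame-free illustration of that snapshot idea; the names UIState and Reducer are hypothetical and are not part of nrtk_explorer. It shows why snapshotting matters: an in-flight reduction keeps the parameters it started with, even if the UI state is edited while it runs.

# Minimal sketch of the parameter-snapshot pattern (illustrative names only).
from dataclasses import dataclass


@dataclass
class UIState:
    tab: str = "PCA"           # stands in for self.state.tab
    dimensionality: str = "3"  # stands in for self.state.dimensionality


class Reducer:
    def __init__(self, state: UIState):
        self.state = state
        self.params: dict = {}

    def save_params(self) -> None:
        # Copy the live UI values once, at the moment the run is requested.
        self.params = {
            "tab": self.state.tab,
            "dimensionality": int(self.state.dimensionality),
        }

    def compute(self) -> str:
        # Read only from the snapshot, never from the live state, so widget
        # changes made mid-computation cannot skew the result.
        return f"{self.params['tab']} reduction to {self.params['dimensionality']}D"


state = UIState()
reducer = Reducer(state)
reducer.save_params()
state.tab = "UMAP"        # user flips the tab while the reduction is still running
print(reducer.compute())  # -> "PCA reduction to 3D"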