
Commit 227d5ef

fix(embeddings): save embedding params on compute
and use saved params when plotting new transformed images. Also, cache transformed images to re-plot them when embedding params are changed. Closes #170 Closes #171
1 parent e095bd0 commit 227d5ef
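
In outline, the change snapshots the embedding parameters when points are computed and reuses that snapshot for any transformed images plotted later, while caching the transformed images themselves so they can be re-plotted after the parameters change. A minimal, self-contained sketch of that pattern (hypothetical names such as live_ui_state, saved_params, cached_transformed, plot_points; not the actual nrtk-explorer API, which appears in the diff below):

    # Hypothetical sketch of the pattern in this commit, not nrtk-explorer code.
    live_ui_state = {"tab": "UMAP", "dimensionality": "3"}  # what the user can edit

    saved_params = {}        # snapshot taken on compute / "Run"
    cached_transformed = {}  # image id -> transformed image (stand-in: str)

    def save_params():
        # copy the UI state at compute time
        saved_params.update(dict(live_ui_state))

    def plot_points(images):
        # always read the snapshot, never the live (possibly edited) UI state
        return [(img, saved_params["tab"], saved_params["dimensionality"]) for img in images]

    def on_run_transformations(id_to_image):
        # cache new transformed images, then plot them under the current snapshot
        cached_transformed.update(id_to_image)
        return plot_points(id_to_image.values())

    def on_params_changed():
        # take a new snapshot and re-plot every cached transformed image
        save_params()
        return plot_points(cached_transformed.values())

    save_params()
    on_run_transformations({"img_1": "blurred_img_1"})
    live_ui_state["tab"] = "PCA"   # user edits params after transforming
    print(on_params_changed())     # cached images re-plotted with the new snapshot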

File tree

1 file changed: +97 −30 lines changed


src/nrtk_explorer/app/embeddings.py (+97 −30)
@@ -1,8 +1,10 @@
+from typing import Dict
+from trame.decorators import TrameApp, change
+from PIL import Image
 from nrtk_explorer.widgets.nrtk_explorer import ScatterPlot
 from nrtk_explorer.library import embeddings_extractor
 from nrtk_explorer.library import dimension_reducers
 from nrtk_explorer.library.dataset import get_dataset
-from nrtk_explorer.library.scoring import partition
 from nrtk_explorer.app.applet import Applet

 from nrtk_explorer.app.images.image_ids import (
@@ -20,6 +22,41 @@
 from trame.app import get_server, asynchronous


+IdToImage = Dict[str, Image.Image]
+
+
+@TrameApp()
+class TransformedImages:
+    def __init__(self, server):
+        self.server = server
+        self.transformed_images: IdToImage = {}
+
+    def emit_update(self):
+        self.server.controller.update_transformed_images(self.transformed_images)
+
+    def add_images(self, dataset_id_to_image: IdToImage):
+        self.transformed_images.update(dataset_id_to_image)
+        self.emit_update()
+
+    @change("dataset_ids")
+    def on_dataset_ids(self, **kwargs):
+        self.transformed_images = {
+            k: v
+            for k, v in self.transformed_images.items()
+            if image_id_to_dataset_id(k) in self.server.state.dataset_ids
+        }
+        self.emit_update()
+
+    @change("current_dataset")
+    def on_dataset(self, **kwargs):
+        self.transformed_images = {}
+        self.emit_update()
+
+    def clear(self, **kwargs):
+        self.transformed_images = {}
+        self.emit_update()
+
+
 class EmbeddingsApp(Applet):
     def __init__(
         self,
@@ -53,15 +90,21 @@ def __init__(
             "id": "",
             "is_transformed": True,
         }
+        self.state.dimensionality = "3"

-    def on_server_ready(self, *args, **kwargs):
+        self.clear_points_transformations()  # init vars
         self.on_feature_extraction_model_change()
-        self.state.change("feature_extraction_model")(self.on_feature_extraction_model_change)

+        self.transformed_images = TransformedImages(server)
+        self.server.controller.update_transformed_images.add(self.update_transformed_images)
+
+    def on_server_ready(self, *args, **kwargs):
+        self.state.change("feature_extraction_model")(self.on_feature_extraction_model_change)
+        self.save_embedding_params()
         self.update_points()
         self.state.change("dataset_ids")(self.update_points)
-
         self.server.controller.apply_transform.add(self.clear_points_transformations)
+        self.server.controller.apply_transform.add(self.transformed_images.clear)
         self.state.change("transform_enabled_switch")(self.update_points_transformations_state)

     def on_feature_extraction_model_change(self, **kwargs):
@@ -75,29 +118,31 @@ def compute_points(self, fit_features, features):
             # reduce will fail if no features
             return []

-        if self.state.tab == "PCA":
+        params = self.embedding_params
+
+        if params["tab"] == "PCA":
             return self.reducer.reduce(
                 name="PCA",
                 fit_features=fit_features,
                 features=features,
-                dims=self.state.dimensionality,
-                whiten=self.state.pca_whiten,
-                solver=self.state.pca_solver,
+                dims=params["dimensionality"],
+                whiten=params["pca_whiten"],
+                solver=params["pca_solver"],
             )

         # must be UMAP
         args = {}
-        if self.state.umap_random_seed:
-            args["random_state"] = int(self.state.umap_random_seed_value)
+        if params["umap_random_seed"]:
+            args["random_state"] = int(params["umap_random_seed_value"])

-        if self.state.umap_n_neighbors:
-            args["n_neighbors"] = int(self.state.umap_n_neighbors_number)
+        if params["umap_n_neighbors"]:
+            args["n_neighbors"] = int(params["umap_n_neighbors_number"])

         return self.reducer.reduce(
             name="UMAP",
             fit_features=fit_features,
             features=features,
-            dims=self.state.dimensionality,
+            dims=params["dimensionality"],
             **args,
         )

@@ -111,14 +156,7 @@ def update_points_transformations_state(self, **kwargs):
         else:
             self.state.points_transformations = {}

-    async def compute_source_points(self):
-        with self.state:
-            self.state.is_loading = True
-            self.clear_points_transformations()
-
-        # Don't lock server before enabling the spinner on client
-        await self.server.network_completion
-
+    def compute_source_points(self):
         images = [
             self.images.get_image_without_cache_eviction(id) for id in self.state.dataset_ids
         ]
@@ -135,36 +173,65 @@ async def compute_source_points(self):

         self.state.camera_position = []

+    async def _update_points(self):
+        with self.state:
+            self.state.is_loading = True
+            self.points_sources = {}
+            self.clear_points_transformations()
+        # Don't lock server before enabling the spinner on client
+        await self.server.network_completion
+
+        self.save_embedding_params()
+
         with self.state:
+            self.compute_source_points()
+            self.update_transformed_images(self.transformed_images.transformed_images)
             self.state.is_loading = False

     def update_points(self, **kwargs):
         if hasattr(self, "_update_task"):
             self._update_task.cancel()
-        self._update_task = asynchronous.create_task(self.compute_source_points())
+        self._update_task = asynchronous.create_task(self._update_points())
+
+    def save_embedding_params(self):
+        self.embedding_params = {
+            "tab": self.state.tab,
+            "dimensionality": self.state.dimensionality,
+            "pca_whiten": self.state.pca_whiten,
+            "pca_solver": self.state.pca_solver,
+            "umap_random_seed": self.state.umap_random_seed,
+            "umap_random_seed_value": self.state.umap_random_seed_value,
+            "umap_n_neighbors": self.state.umap_n_neighbors,
+            "umap_n_neighbors_number": self.state.umap_n_neighbors_number,
+        }

     def on_run_clicked(self):
+        self.save_embedding_params()
         self.update_points()

     def on_run_transformations(self, id_to_image):
-        hits, misses = partition(
-            lambda id: image_id_to_dataset_id(id) in self._stashed_points_transformations,
-            id_to_image.keys(),
-        )
+        self.transformed_images.add_images(id_to_image)
+
+    def update_transformed_images(self, id_to_image):
+        new_to_plot = {
+            id: img
+            for id, img in id_to_image.items()
+            if image_id_to_dataset_id(id) not in self._stashed_points_transformations
+        }

-        to_plot = {id: id_to_image[id] for id in misses}
         transformation_features = self.extractor.extract(
-            list(to_plot.values()),
+            list(new_to_plot.values()),
             batch_size=int(self.state.model_batch_size),
         )
         points = self.compute_points(self.features, transformation_features)
-        ids_to_points = zip(to_plot.keys(), points)
+        image_id_to_point = zip(new_to_plot.keys(), points)

-        updated_points = {image_id_to_dataset_id(id): point for id, point in ids_to_points}
+        updated_points = {image_id_to_dataset_id(id): point for id, point in image_id_to_point}
         self._stashed_points_transformations = {
             **self._stashed_points_transformations,
             **updated_points,
         }
+
         self.update_points_transformations_state()

     # called by category filter
