@@ -1,8 +1,10 @@
+from typing import Dict
+from trame.decorators import TrameApp, change
+from PIL import Image
 from nrtk_explorer.widgets.nrtk_explorer import ScatterPlot
 from nrtk_explorer.library import embeddings_extractor
 from nrtk_explorer.library import dimension_reducers
 from nrtk_explorer.library.dataset import get_dataset
-from nrtk_explorer.library.scoring import partition
 from nrtk_explorer.app.applet import Applet

 from nrtk_explorer.app.images.image_ids import (
@@ -20,6 +22,41 @@
 from trame.app import get_server, asynchronous


+IdToImage = Dict[str, Image.Image]
+
+
+@TrameApp()
+class TransformedImages:
+    def __init__(self, server):
+        self.server = server
+        self.transformed_images: IdToImage = {}
+
+    def emit_update(self):
+        self.server.controller.update_transformed_images(self.transformed_images)
+
+    def add_images(self, dataset_id_to_image: IdToImage):
+        self.transformed_images.update(dataset_id_to_image)
+        self.emit_update()
+
+    @change("dataset_ids")
+    def on_dataset_ids(self, **kwargs):
+        self.transformed_images = {
+            k: v
+            for k, v in self.transformed_images.items()
+            if image_id_to_dataset_id(k) in self.server.state.dataset_ids
+        }
+        self.emit_update()
+
+    @change("current_dataset")
+    def on_dataset(self, **kwargs):
+        self.transformed_images = {}
+        self.emit_update()
+
+    def clear(self, **kwargs):
+        self.transformed_images = {}
+        self.emit_update()
+
+
 class EmbeddingsApp(Applet):
     def __init__(
         self,
@@ -53,15 +90,21 @@ def __init__(
             "id": "",
             "is_transformed": True,
         }
+        self.state.dimensionality = "3"

-    def on_server_ready(self, *args, **kwargs):
+        self.clear_points_transformations()  # init vars
         self.on_feature_extraction_model_change()
-        self.state.change("feature_extraction_model")(self.on_feature_extraction_model_change)

+        self.transformed_images = TransformedImages(server)
+        self.server.controller.update_transformed_images.add(self.update_transformed_images)
+
+    def on_server_ready(self, *args, **kwargs):
+        self.state.change("feature_extraction_model")(self.on_feature_extraction_model_change)
+        self.save_embedding_params()
         self.update_points()
         self.state.change("dataset_ids")(self.update_points)
-
         self.server.controller.apply_transform.add(self.clear_points_transformations)
+        self.server.controller.apply_transform.add(self.transformed_images.clear)
         self.state.change("transform_enabled_switch")(self.update_points_transformations_state)

     def on_feature_extraction_model_change(self, **kwargs):
@@ -75,29 +118,31 @@ def compute_points(self, fit_features, features):
             # reduce will fail if no features
             return []

-        if self.state.tab == "PCA":
+        params = self.embedding_params
+
+        if params["tab"] == "PCA":
             return self.reducer.reduce(
                 name="PCA",
                 fit_features=fit_features,
                 features=features,
-                dims=self.state.dimensionality,
-                whiten=self.state.pca_whiten,
-                solver=self.state.pca_solver,
+                dims=params["dimensionality"],
+                whiten=params["pca_whiten"],
+                solver=params["pca_solver"],
             )

         # must be UMAP
         args = {}
-        if self.state.umap_random_seed:
-            args["random_state"] = int(self.state.umap_random_seed_value)
+        if params["umap_random_seed"]:
+            args["random_state"] = int(params["umap_random_seed_value"])

-        if self.state.umap_n_neighbors:
-            args["n_neighbors"] = int(self.state.umap_n_neighbors_number)
+        if params["umap_n_neighbors"]:
+            args["n_neighbors"] = int(params["umap_n_neighbors_number"])

         return self.reducer.reduce(
             name="UMAP",
             fit_features=fit_features,
             features=features,
-            dims=self.state.dimensionality,
+            dims=params["dimensionality"],
             **args,
         )

@@ -111,14 +156,7 @@ def update_points_transformations_state(self, **kwargs):
         else:
             self.state.points_transformations = {}

-    async def compute_source_points(self):
-        with self.state:
-            self.state.is_loading = True
-            self.clear_points_transformations()
-
-        # Don't lock server before enabling the spinner on client
-        await self.server.network_completion
-
+    def compute_source_points(self):
         images = [
             self.images.get_image_without_cache_eviction(id) for id in self.state.dataset_ids
         ]
@@ -135,36 +173,65 @@ async def compute_source_points(self):

         self.state.camera_position = []

+    async def _update_points(self):
+        with self.state:
+            self.state.is_loading = True
+            self.points_sources = {}
+            self.clear_points_transformations()
+        # Don't lock server before enabling the spinner on client
+        await self.server.network_completion
+
+        self.save_embedding_params()
+
         with self.state:
+            self.compute_source_points()
+            self.update_transformed_images(self.transformed_images.transformed_images)
             self.state.is_loading = False

     def update_points(self, **kwargs):
         if hasattr(self, "_update_task"):
             self._update_task.cancel()
-        self._update_task = asynchronous.create_task(self.compute_source_points())
+        self._update_task = asynchronous.create_task(self._update_points())
+
+    def save_embedding_params(self):
+        self.embedding_params = {
+            "tab": self.state.tab,
+            "dimensionality": self.state.dimensionality,
+            "pca_whiten": self.state.pca_whiten,
+            "pca_solver": self.state.pca_solver,
+            "umap_random_seed": self.state.umap_random_seed,
+            "umap_random_seed_value": self.state.umap_random_seed_value,
+            "umap_n_neighbors": self.state.umap_n_neighbors,
+            "umap_n_neighbors_number": self.state.umap_n_neighbors_number,
+        }

     def on_run_clicked(self):
+        self.save_embedding_params()
         self.update_points()

     def on_run_transformations(self, id_to_image):
-        hits, misses = partition(
-            lambda id: image_id_to_dataset_id(id) in self._stashed_points_transformations,
-            id_to_image.keys(),
-        )
+        self.transformed_images.add_images(id_to_image)
+
+    def update_transformed_images(self, id_to_image):
+        new_to_plot = {
+            id: img
+            for id, img in id_to_image.items()
+            if image_id_to_dataset_id(id) not in self._stashed_points_transformations
+        }

-        to_plot = {id: id_to_image[id] for id in misses}
         transformation_features = self.extractor.extract(
-            list(to_plot.values()),
+            list(new_to_plot.values()),
             batch_size=int(self.state.model_batch_size),
         )
         points = self.compute_points(self.features, transformation_features)
-        ids_to_points = zip(to_plot.keys(), points)
+        image_id_to_point = zip(new_to_plot.keys(), points)

-        updated_points = {image_id_to_dataset_id(id): point for id, point in ids_to_points}
+        updated_points = {image_id_to_dataset_id(id): point for id, point in image_id_to_point}
         self._stashed_points_transformations = {
             **self._stashed_points_transformations,
             **updated_points,
         }
+
         self.update_points_transformations_state()

     # called by category filter
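
A minimal sketch, not part of the diff, of how the new TransformedImages holder and its controller wiring could be exercised on their own. It assumes the class is importable from this module (the import path shown is hypothetical), and the subscribed lambda and fake_image are illustrative stand-ins; the pattern mirrors what EmbeddingsApp does via update_transformed_images and apply_transform.

# Sketch only: names marked hypothetical are not from the PR.
from PIL import Image
from trame.app import get_server

from nrtk_explorer.app.embeddings import TransformedImages  # hypothetical import path

server = get_server()
server.state.dataset_ids = []  # state key watched by the @change("dataset_ids") handler

holder = TransformedImages(server)

# Downstream code subscribes to the controller trigger that emit_update() fires.
server.controller.update_transformed_images.add(
    lambda images: print(f"{len(images)} transformed image(s) available")
)

# add_images() merges new transformed images and re-emits the full id -> image mapping.
fake_image = Image.new("RGB", (8, 8))  # illustrative stand-in for a transformed image
holder.add_images({"transformed_img_0": fake_image})

# clear() is what EmbeddingsApp hooks to apply_transform, so stale transformed
# images are dropped when a new transform run starts; it notifies consumers again.
holder.clear()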