diff --git a/pe/embedding/embedding.py b/pe/embedding/embedding.py
index cfc3827..54fadbd 100644
--- a/pe/embedding/embedding.py
+++ b/pe/embedding/embedding.py
@@ -1,6 +1,7 @@
 from abc import ABC, abstractmethod
 
 from pe.constant.data import EMBEDDING_COLUMN_NAME
+from pe.data import Data
 
 
 class Embedding(ABC):
@@ -19,3 +20,15 @@ def compute_embedding(self, data):
         :type data: :py:class:`pe.data.data.Data`
         """
         ...
+
+    def filter_uncomputed_rows(self, data):
+        data_frame = data.data_frame
+        if self.column_name in data_frame.columns:
+            data_frame = data_frame[data_frame[self.column_name].isna()]
+        return Data(data_frame=data_frame, metadata=data.metadata)
+
+    def merge_computed_rows(self, data, computed_data):
+        data_frame = data.data_frame
+        computed_data_frame = computed_data.data_frame
+        data_frame.update(computed_data_frame)
+        return Data(data_frame=data_frame, metadata=data.metadata)
diff --git a/pe/embedding/image/inception.py b/pe/embedding/image/inception.py
index c6a8e67..175b319 100644
--- a/pe/embedding/image/inception.py
+++ b/pe/embedding/image/inception.py
@@ -54,11 +54,15 @@ def compute_embedding(self, data):
         :return: The data object with the computed embedding
         :rtype: :py:class:`pe.data.data.Data`
         """
-        if self.column_name in data.data_frame.columns:
+        uncomputed_data = self.filter_uncomputed_rows(data)
+        if len(uncomputed_data.data_frame) == 0:
             execution_logger.info(f"Embedding: {self.column_name} already computed")
             return data
-        execution_logger.info(f"Embedding: computing {self.column_name} for {len(data.data_frame)} samples")
-        x = np.stack(data.data_frame[IMAGE_DATA_COLUMN_NAME].values, axis=0)
+        execution_logger.info(
+            f"Embedding: computing {self.column_name} for {len(uncomputed_data.data_frame)}/{len(data.data_frame)}"
+            " samples"
+        )
+        x = np.stack(uncomputed_data.data_frame[IMAGE_DATA_COLUMN_NAME].values, axis=0)
         if x.shape[3] == 1:
             x = np.repeat(x, 3, axis=3)
         embeddings = []
@@ -74,6 +78,11 @@ def compute_embedding(self, data):
             embeddings.append(self._inception(torch.from_numpy(transformed_x).to(self._device)))
         embeddings = torch.cat(embeddings, dim=0)
         embeddings = embeddings.cpu().detach().numpy()
-        data.data_frame[self.column_name] = pd.Series(list(embeddings), index=data.data_frame.index)
-        execution_logger.info(f"Embedding: finished computing {self.column_name} for {len(data.data_frame)} samples")
-        return data
+        uncomputed_data.data_frame[self.column_name] = pd.Series(
+            list(embeddings), index=uncomputed_data.data_frame.index
+        )
+        execution_logger.info(
+            f"Embedding: finished computing {self.column_name} for "
+            f"{len(uncomputed_data.data_frame)}/{len(data.data_frame)} samples"
+        )
+        return self.merge_computed_rows(data, uncomputed_data)