Skip to content

Commit

Permalink
computing embedding only for missing rows
Browse files Browse the repository at this point in the history
  • Loading branch information
fjxmlzn committed Dec 27, 2024
1 parent bbd6c54 commit 7080993
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 6 deletions.
13 changes: 13 additions & 0 deletions pe/embedding/embedding.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from abc import ABC, abstractmethod

from pe.constant.data import EMBEDDING_COLUMN_NAME
from pe.data import Data


class Embedding(ABC):
Expand All @@ -19,3 +20,15 @@ def compute_embedding(self, data):
:type data: :py:class:`pe.data.data.Data`
"""
...

def filter_uncomputed_rows(self, data):
data_frame = data.data_frame
if self.column_name in data_frame.columns:
data_frame = data_frame[data_frame[self.column_name].isna()]
return Data(data_frame=data_frame, metadata=data.metadata)

def merge_computed_rows(self, data, computed_data):
data_frame = data.data_frame
computed_data_frame = computed_data.data_frame
data_frame.update(computed_data_frame)
return Data(data_frame=data_frame, metadata=data.metadata)
21 changes: 15 additions & 6 deletions pe/embedding/image/inception.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,15 @@ def compute_embedding(self, data):
:return: The data object with the computed embedding
:rtype: :py:class:`pe.data.data.Data`
"""
if self.column_name in data.data_frame.columns:
uncomputed_data = self.filter_uncomputed_rows(data)
if len(uncomputed_data.data_frame) == 0:
execution_logger.info(f"Embedding: {self.column_name} already computed")
return data
execution_logger.info(f"Embedding: computing {self.column_name} for {len(data.data_frame)} samples")
x = np.stack(data.data_frame[IMAGE_DATA_COLUMN_NAME].values, axis=0)
execution_logger.info(
f"Embedding: computing {self.column_name} for {len(uncomputed_data.data_frame)}/{len(data.data_frame)}"
" samples"
)
x = np.stack(uncomputed_data.data_frame[IMAGE_DATA_COLUMN_NAME].values, axis=0)
if x.shape[3] == 1:
x = np.repeat(x, 3, axis=3)
embeddings = []
Expand All @@ -74,6 +78,11 @@ def compute_embedding(self, data):
embeddings.append(self._inception(torch.from_numpy(transformed_x).to(self._device)))
embeddings = torch.cat(embeddings, dim=0)
embeddings = embeddings.cpu().detach().numpy()
data.data_frame[self.column_name] = pd.Series(list(embeddings), index=data.data_frame.index)
execution_logger.info(f"Embedding: finished computing {self.column_name} for {len(data.data_frame)} samples")
return data
uncomputed_data.data_frame[self.column_name] = pd.Series(
list(embeddings), index=uncomputed_data.data_frame.index
)
execution_logger.info(
f"Embedding: finished computing {self.column_name} for "
f"{len(uncomputed_data.data_frame)}/{len(data.data_frame)} samples"
)
return self.merge_computed_rows(data, uncomputed_data)

0 comments on commit 7080993

Please sign in to comment.