Skip to content

Commit

Permalink
perf(dataset): convert image mode to RGB on HF dataset load
Browse files Browse the repository at this point in the history
  • Loading branch information
PaulHax committed Nov 13, 2024
1 parent cd0123c commit 2134e0d
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion src/nrtk_explorer/library/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
get_dataset_infos,
Sequence as SequenceDataset,
ClassLabel,
Image as DatasetImage,
)

HF_ROWS_TO_TAKE_STREAMING = 300
Expand Down Expand Up @@ -101,6 +102,8 @@ def __init__(self, identifier: str):
repo, config, split, streaming = identifier.split("@")
self._streaming = streaming == "streaming"
self._dataset = load_dataset(repo, config, split=split, streaming=self._streaming)
# transforms and base64 encoding require RGB mode
self._dataset.cast_column("image", DatasetImage(mode="RGB"))
if self._streaming:
self._dataset = self._dataset.take(HF_ROWS_TO_TAKE_STREAMING)
self.imgs: dict[str, dict] = {}
Expand All @@ -124,7 +127,7 @@ def extract_labels(feature):
return feature.names
if isinstance(feature, SequenceDataset):
return extract_labels(feature.feature)
if isinstance(feature, list):
if isinstance(feature, list) and len(feature) >= 1:
return extract_labels(feature[0])
if isinstance(feature, dict):
for key in ["category", "category_id", "label", "labels", "objects"]:
Expand All @@ -139,6 +142,7 @@ def extract_labels(feature):
self.cats = {i: {"id": i, "name": str(name)} for i, name in enumerate(labels)}

new_cats = set()
# speed initial metadata process by not loading images if we can random access rows (not streaming)
maybe_no_image = (
self._dataset if self._streaming else self._dataset.remove_columns(["image"])
)
Expand Down

0 comments on commit 2134e0d

Please sign in to comment.