feat(dataset): add download CLI argument for HF datasets
HF datasets are streamed by default, but you can now pass
`--download` to download the dataset to disk before loading it.
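
For context, a minimal sketch (not the app's code) of the two Hugging Face loading modes the new flag selects between; the repo id "cifar10" is only an illustration:

from datasets import load_dataset

# Streaming (the default here): rows are fetched lazily over the network,
# so no full copy of the dataset is written to disk.
streamed = load_dataset("cifar10", split="train", streaming=True)

# With --download: the dataset is downloaded and cached locally, then loaded.
downloaded = load_dataset("cifar10", split="train", streaming=False)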
PaulHax committed Nov 12, 2024
1 parent 021122d commit 2aaf3fc
Showing 2 changed files with 19 additions and 11 deletions.
src/nrtk_explorer/app/core.py (13 changes: 11 additions, 2 deletions)
@@ -47,11 +47,20 @@ def __init__(self, server=None):
"--dataset",
nargs="+",
default=DEFAULT_DATASETS,
help="Path of the json file describing the image dataset",
help="Path to the JSON file describing the image dataset",
)

self.server.cli.add_argument(
"--download",
action="store_true",
default=False,
help="Download Hugging Face Hub datasets instead of streaming them directly",
)

known_args, _ = self.server.cli.parse_known_args()
dataset_identifiers = expand_hugging_face_datasets(known_args.dataset)
dataset_identifiers = expand_hugging_face_datasets(
known_args.dataset, not known_args.download
)
self.input_paths = dataset_identifiers
self.state.current_dataset = self.input_paths[0]

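To see the core.py change in isolation, here is a minimal argparse sketch of the same pattern (a stand-alone illustration, not the project's trame-based CLI): a `--download` store-true flag whose negation is forwarded as the streaming argument.

import argparse

# Stand-in for self.server.cli; the real app registers these on trame's CLI object.
parser = argparse.ArgumentParser()
parser.add_argument("--dataset", nargs="+", default=["example/dataset"])
parser.add_argument(
    "--download",
    action="store_true",
    default=False,
    help="Download Hugging Face Hub datasets instead of streaming them directly",
)

args, _ = parser.parse_known_args()
# Omitting --download keeps streaming on, mirroring `not known_args.download` above.
streaming = not args.download
print(f"datasets={args.dataset} streaming={streaming}")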
src/nrtk_explorer/library/dataset.py (17 changes: 8 additions, 9 deletions)
@@ -18,8 +18,7 @@
     ClassLabel,
 )

-HF_ROWS_MAX_TO_DOWNLOAD = 5000
-HF_ROWS_TO_TAKE_STREAMING = 600
+HF_ROWS_TO_TAKE_STREAMING = 300


 class BaseDataset:
@@ -72,7 +71,7 @@ def is_coco_dataset(path: str):
     return all(key in content for key in required_keys)


-def expand_hugging_face_datasets(dataset_identifiers: SequenceType[str]):
+def expand_hugging_face_datasets(dataset_identifiers: SequenceType[str], streaming=True):
     expanded_identifiers = []
     for identifier in dataset_identifiers:
         if is_coco_dataset(identifier):
@@ -81,19 +80,19 @@ def expand_hugging_face_datasets(dataset_identifiers: SequenceType[str]):
             infos = get_dataset_infos(identifier)
             for config_name, info in infos.items():
                 for split_name in info.splits:
-                    expanded_identifiers.append(f"{identifier}@{config_name}@{split_name}")
+                    streaming_str = "streaming" if streaming else "download"
+                    expanded_identifiers.append(
+                        f"{identifier}@{config_name}@{split_name}@{streaming_str}"
+                    )
     return expanded_identifiers


 class HuggingFaceDataset(BaseDataset):
     """Interface for Hugging Face datasets with a similar API to JsonDataset."""

     def __init__(self, identifier: str):
-        repo, config, split = identifier.split("@")
-        infos = get_dataset_infos(repo)[config]
-        split_info = infos.splits[split]
-        num_examples = split_info.num_examples if hasattr(split_info, "num_examples") else None
-        self._streaming = num_examples is None or num_examples > HF_ROWS_MAX_TO_DOWNLOAD
+        repo, config, split, streaming = identifier.split("@")
+        self._streaming = streaming == "streaming"
         self._dataset = load_dataset(repo, config, split=split, streaming=self._streaming)
         if self._streaming:
             self._dataset = self._dataset.take(HF_ROWS_TO_TAKE_STREAMING)
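A rough, self-contained sketch of the identifier round trip the diff introduces (the helper names below are illustrative, not part of the library): expansion now appends the load mode as a fourth `@`-separated field, and the constructor splits it back out to decide whether `load_dataset` streams or downloads.

from datasets import load_dataset

HF_ROWS_TO_TAKE_STREAMING = 300  # mirrors the constant in dataset.py


def make_identifier(repo: str, config: str, split: str, streaming: bool = True) -> str:
    # Same encoding as expand_hugging_face_datasets: repo@config@split@{streaming|download}
    mode = "streaming" if streaming else "download"
    return f"{repo}@{config}@{split}@{mode}"


def load_from_identifier(identifier: str):
    # Mirrors HuggingFaceDataset.__init__: decode the fields, then stream or download.
    repo, config, split, mode = identifier.split("@")
    streaming = mode == "streaming"
    dataset = load_dataset(repo, config, split=split, streaming=streaming)
    if streaming:
        # Streamed datasets are capped to a small number of rows.
        dataset = dataset.take(HF_ROWS_TO_TAKE_STREAMING)
    return dataset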
