From 2aaf3fc78818bf76b27a4f5ff29233066f7e4753 Mon Sep 17 00:00:00 2001
From: Paul Elliott
Date: Tue, 12 Nov 2024 10:27:32 -0500
Subject: [PATCH] feat(dataset): add download CLI argument for HF datasets

HF datasets still default to streaming, but you can now pass
`--download` to download the dataset to disk before loading it.
---
 src/nrtk_explorer/app/core.py        | 13 +++++++++++--
 src/nrtk_explorer/library/dataset.py | 17 ++++++++---------
 2 files changed, 19 insertions(+), 11 deletions(-)

diff --git a/src/nrtk_explorer/app/core.py b/src/nrtk_explorer/app/core.py
index dc35f18..e0d520c 100644
--- a/src/nrtk_explorer/app/core.py
+++ b/src/nrtk_explorer/app/core.py
@@ -47,11 +47,20 @@ def __init__(self, server=None):
             "--dataset",
             nargs="+",
             default=DEFAULT_DATASETS,
-            help="Path of the json file describing the image dataset",
+            help="Path to the JSON file describing the image dataset",
+        )
+
+        self.server.cli.add_argument(
+            "--download",
+            action="store_true",
+            default=False,
+            help="Download Hugging Face Hub datasets instead of streaming them directly",
         )
 
         known_args, _ = self.server.cli.parse_known_args()
-        dataset_identifiers = expand_hugging_face_datasets(known_args.dataset)
+        dataset_identifiers = expand_hugging_face_datasets(
+            known_args.dataset, not known_args.download
+        )
         self.input_paths = dataset_identifiers
         self.state.current_dataset = self.input_paths[0]
 
diff --git a/src/nrtk_explorer/library/dataset.py b/src/nrtk_explorer/library/dataset.py
index 4ec1977..8b48f03 100644
--- a/src/nrtk_explorer/library/dataset.py
+++ b/src/nrtk_explorer/library/dataset.py
@@ -18,8 +18,7 @@
     ClassLabel,
 )
 
-HF_ROWS_MAX_TO_DOWNLOAD = 5000
-HF_ROWS_TO_TAKE_STREAMING = 600
+HF_ROWS_TO_TAKE_STREAMING = 300
 
 
 class BaseDataset:
@@ -72,7 +71,7 @@ def is_coco_dataset(path: str):
     return all(key in content for key in required_keys)
 
 
-def expand_hugging_face_datasets(dataset_identifiers: SequenceType[str]):
+def expand_hugging_face_datasets(dataset_identifiers: SequenceType[str], streaming=True):
     expanded_identifiers = []
     for identifier in dataset_identifiers:
         if is_coco_dataset(identifier):
@@ -81,7 +80,10 @@ def expand_hugging_face_datasets(dataset_identifiers: SequenceType[str]):
             infos = get_dataset_infos(identifier)
             for config_name, info in infos.items():
                 for split_name in info.splits:
-                    expanded_identifiers.append(f"{identifier}@{config_name}@{split_name}")
+                    streaming_str = "streaming" if streaming else "download"
+                    expanded_identifiers.append(
+                        f"{identifier}@{config_name}@{split_name}@{streaming_str}"
+                    )
     return expanded_identifiers
 
 
@@ -89,11 +91,8 @@ class HuggingFaceDataset(BaseDataset):
     """Interface for Hugging Face datasets with a similar API to JsonDataset."""
 
     def __init__(self, identifier: str):
-        repo, config, split = identifier.split("@")
-        infos = get_dataset_infos(repo)[config]
-        split_info = infos.splits[split]
-        num_examples = split_info.num_examples if hasattr(split_info, "num_examples") else None
-        self._streaming = num_examples is None or num_examples > HF_ROWS_MAX_TO_DOWNLOAD
+        repo, config, split, streaming = identifier.split("@")
+        self._streaming = streaming == "streaming"
         self._dataset = load_dataset(repo, config, split=split, streaming=self._streaming)
         if self._streaming:
             self._dataset = self._dataset.take(HF_ROWS_TO_TAKE_STREAMING)
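
Note for reviewers (not part of the patch): the short sketch below is a standalone
illustration of how the new four-field identifier produced by
expand_hugging_face_datasets is meant to round-trip into HuggingFaceDataset.__init__.
The parse_streaming_identifier helper and the example repo/config/split names are
hypothetical, written only to mirror the split("@") logic in the diff above.

    # Illustrative only: mirrors the "<repo>@<config>@<split>@<streaming|download>"
    # identifier format introduced by this patch; this helper does not exist
    # in the codebase.
    def parse_streaming_identifier(identifier: str):
        repo, config, split, mode = identifier.split("@")
        return repo, config, split, mode == "streaming"

    # Example identifier (repo/config/split names are placeholders).
    repo, config, split, streaming = parse_streaming_identifier(
        "some-org/some-dataset@default@train@download"
    )
    assert streaming is False  # identifiers built with --download end in "download"

Design note: passing --download sets known_args.download, so
expand_hugging_face_datasets is called with streaming=False and tags every expanded
identifier with "download"; HuggingFaceDataset then calls load_dataset(...,
streaming=False) and skips the HF_ROWS_TO_TAKE_STREAMING truncation.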