Skip to content

Commit aa0191f

Browse files
reedwm fyangf
authored and committed
Internal change
PiperOrigin-RevId: 438897452
1 parent 6845172 commit aa0191f

File tree

3 files changed

+12
-1
lines changed

3 files changed

+12
-1
lines changed

official/core/config_definitions.py

+5
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,10 @@ class DataConfig(base_config.Config):
7575
features. The main use case is to skip the image/video decoding for better
7676
performance.
7777
seed: An optional seed to use for deterministic shuffling/preprocessing.
78+
prefetch_buffer_size: An int specifying the buffer size of prefetch
79+
datasets. If None, the buffer size is autotuned. Specifying this is useful
80+
in case autotuning uses up too much memory by making the buffer size too
81+
high.
7882
"""
7983
input_path: Union[Sequence[str], str, base_config.Config] = ""
8084
tfds_name: str = ""
@@ -95,6 +99,7 @@ class DataConfig(base_config.Config):
9599
tfds_as_supervised: bool = False
96100
tfds_skip_decoding_feature: str = ""
97101
seed: Optional[int] = None
102+
prefetch_buffer_size: Optional[int] = None
98103

99104

100105
@dataclasses.dataclass

official/core/input_reader.py

+3 -1
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,8 @@ def __init__(self,
270270
self._transform_and_batch_fn = transform_and_batch_fn
271271
self._postprocess_fn = postprocess_fn
272272
self._seed = params.seed
273+
self._prefetch_buffer_size = (params.prefetch_buffer_size or
274+
tf.data.experimental.AUTOTUNE)
273275

274276
# When tf.data service is enabled, each data service worker should get
275277
# different random seeds. Thus, we set `seed` to None.
@@ -475,4 +477,4 @@ def read(self,
475477
options = tf.data.Options()
476478
options.experimental_deterministic = self._deterministic
477479
dataset = dataset.with_options(options)
478-
return dataset.prefetch(tf.data.experimental.AUTOTUNE)
480+
return dataset.prefetch(self._prefetch_buffer_size)

official/vision/beta/configs/experiments/image_classification/imagenet_resnet50_gpu.yaml

+4
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,16 @@ task:
1919
is_training: true
2020
global_batch_size: 2048
2121
dtype: 'float16'
22+
# Autotuning the prefetch buffer size causes OOMs, so set it to a reasonable
23+
# static value: 32. See b/218880025.
24+
prefetch_buffer_size: 32
2225
validation_data:
2326
input_path: 'imagenet-2012-tfrecord/valid*'
2427
is_training: false
2528
global_batch_size: 2048
2629
dtype: 'float16'
2730
drop_remainder: false
31+
prefetch_buffer_size: 32
2832
trainer:
2933
train_steps: 56160
3034
validation_steps: 25

0 commit comments

Comments
 (0)