File tree: 3 files changed, +12 −1 lines changed
vision/beta/configs/experiments/image_classification
3 files changed, +12 −1 lines changed
@@ -75,6 +75,10 @@ class DataConfig(base_config.Config):
75
75
features. The main use case is to skip the image/video decoding for better
76
76
performance.
77
77
seed: An optional seed to use for deterministic shuffling/preprocessing.
78
+ prefetch_buffer_size: An int specifying the buffer size of prefetch
79
+ datasets. If None, the buffer size is autotuned. Specifying this is useful
80
+ in case autotuning uses up too much memory by making the buffer size too
81
+ high.
78
82
"""
79
83
input_path : Union [Sequence [str ], str , base_config .Config ] = ""
80
84
tfds_name : str = ""
@@ -95,6 +99,7 @@ class DataConfig(base_config.Config):
95
99
tfds_as_supervised : bool = False
96
100
tfds_skip_decoding_feature : str = ""
97
101
seed : Optional [int ] = None
102
+ prefetch_buffer_size : Optional [int ] = None
98
103
99
104
100
105
@dataclasses .dataclass
Original file line number Diff line number Diff line change @@ -270,6 +270,8 @@ def __init__(self,
270
270
self ._transform_and_batch_fn = transform_and_batch_fn
271
271
self ._postprocess_fn = postprocess_fn
272
272
self ._seed = params .seed
273
+ self ._prefetch_buffer_size = (params .prefetch_buffer_size or
274
+ tf .data .experimental .AUTOTUNE )
273
275
274
276
# When tf.data service is enabled, each data service worker should get
275
277
# different random seeds. Thus, we set `seed` to None.
@@ -475,4 +477,4 @@ def read(self,
475
477
options = tf .data .Options ()
476
478
options .experimental_deterministic = self ._deterministic
477
479
dataset = dataset .with_options (options )
478
- return dataset .prefetch (tf . data . experimental . AUTOTUNE )
480
+ return dataset .prefetch (self . _prefetch_buffer_size )
Original file line number Diff line number Diff line change @@ -19,12 +19,16 @@ task:
19
19
is_training : true
20
20
global_batch_size : 2048
21
21
dtype : ' float16'
22
+ # Autotuning the prefetch buffer size causes OOMs, so set it to a reasonable
23
+ # static value: 32. See b/218880025.
24
+ prefetch_buffer_size : 32
22
25
validation_data :
23
26
input_path : ' imagenet-2012-tfrecord/valid*'
24
27
is_training : false
25
28
global_batch_size : 2048
26
29
dtype : ' float16'
27
30
drop_remainder : false
31
+ prefetch_buffer_size : 32
28
32
trainer :
29
33
train_steps : 56160
30
34
validation_steps : 25
You can’t perform that action at this time.
0 commit comments