File tree 2 files changed +5
-5
lines changed
official/vision/image_classification
2 files changed +5
-5
lines changed Original file line number Diff line number Diff line change @@ -101,8 +101,8 @@ def get_image_size_from_model(
101
101
def _get_dataset_builders (params : base_configs .ExperimentConfig ,
102
102
strategy : tf .distribute .Strategy ,
103
103
one_hot : bool
104
- ) -> Tuple [Any , Any , Any ]:
105
- """Create and return train, validation, and test dataset builders."""
104
+ ) -> Tuple [Any , Any ]:
105
+ """Create and return train and validation dataset builders."""
106
106
if one_hot :
107
107
logging .warning ('label_smoothing > 0, so datasets will be one hot encoded.' )
108
108
else :
Original file line number Diff line number Diff line change @@ -116,7 +116,7 @@ class DatasetConfig(base_config.Config):
116
116
num_channels : Union [int , str ] = 'infer'
117
117
num_examples : Union [int , str ] = 'infer'
118
118
batch_size : int = 128
119
- use_per_replica_batch_size : bool = False
119
+ use_per_replica_batch_size : bool = True
120
120
num_devices : int = 1
121
121
dtype : str = 'float32'
122
122
one_hot : bool = True
@@ -185,14 +185,14 @@ def is_training(self) -> bool:
185
185
def batch_size (self ) -> int :
186
186
"""The batch size, multiplied by the number of replicas (if configured)."""
187
187
if self .config .use_per_replica_batch_size :
188
- return self .global_batch_size
188
+ return self .config . batch_size * self . config . num_devices
189
189
else :
190
190
return self .config .batch_size
191
191
192
192
@property
193
193
def global_batch_size (self ):
194
194
"""The global batch size across all replicas."""
195
- return self .config . batch_size * self . config . num_devices
195
+ return self .batch_size
196
196
197
197
@property
198
198
def num_steps (self ) -> int :
You can’t perform that action at this time.
0 commit comments