Skip to content

Commit 04d90fe

Browse files
author
Allen Wang
committed
Internal change
PiperOrigin-RevId: 306699912
1 parent e9a1025 commit 04d90fe

File tree

2 files changed

+5
-5
lines changed

2 files changed

+5
-5
lines changed

official/vision/image_classification/classifier_trainer.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -101,8 +101,8 @@ def get_image_size_from_model(
101101
def _get_dataset_builders(params: base_configs.ExperimentConfig,
102102
strategy: tf.distribute.Strategy,
103103
one_hot: bool
104-
) -> Tuple[Any, Any, Any]:
105-
"""Create and return train, validation, and test dataset builders."""
104+
) -> Tuple[Any, Any]:
105+
"""Create and return train and validation dataset builders."""
106106
if one_hot:
107107
logging.warning('label_smoothing > 0, so datasets will be one hot encoded.')
108108
else:

official/vision/image_classification/dataset_factory.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ class DatasetConfig(base_config.Config):
116116
num_channels: Union[int, str] = 'infer'
117117
num_examples: Union[int, str] = 'infer'
118118
batch_size: int = 128
119-
use_per_replica_batch_size: bool = False
119+
use_per_replica_batch_size: bool = True
120120
num_devices: int = 1
121121
dtype: str = 'float32'
122122
one_hot: bool = True
@@ -185,14 +185,14 @@ def is_training(self) -> bool:
185185
def batch_size(self) -> int:
186186
"""The batch size, multiplied by the number of replicas (if configured)."""
187187
if self.config.use_per_replica_batch_size:
188-
return self.global_batch_size
188+
return self.config.batch_size * self.config.num_devices
189189
else:
190190
return self.config.batch_size
191191

192192
@property
193193
def global_batch_size(self):
194194
"""The global batch size across all replicas."""
195-
return self.config.batch_size * self.config.num_devices
195+
return self.batch_size
196196

197197
@property
198198
def num_steps(self) -> int:

0 commit comments

Comments
 (0)