diff --git a/test/test_train_mp_imagenet.py b/test/test_train_mp_imagenet.py index c2622823642..7a7a1300bb3 100644 --- a/test/test_train_mp_imagenet.py +++ b/test/test_train_mp_imagenet.py @@ -24,6 +24,9 @@ '--lr_scheduler_divisor': { 'type': int, }, + '--test_only_at_end': { + 'action': 'store_true', + }, } FLAGS = args_parse.parse_common_options( @@ -236,15 +239,16 @@ def test_loop_fn(loader, epoch): xm.master_print('Epoch {} train begin {}'.format(epoch, test_utils.now())) train_loop_fn(train_device_loader, epoch) xm.master_print('Epoch {} train end {}'.format(epoch, test_utils.now())) - accuracy = test_loop_fn(test_device_loader, epoch) - xm.master_print('Epoch {} test end {}, Accuracy={:.2f}'.format( - epoch, test_utils.now(), accuracy)) - max_accuracy = max(accuracy, max_accuracy) - test_utils.write_to_summary( - writer, - epoch, - dict_to_write={'Accuracy/test': accuracy}, - write_xla_metrics=True) + if not FLAGS.test_only_at_end or epoch == FLAGS.num_epochs: + accuracy = test_loop_fn(test_device_loader, epoch) + xm.master_print('Epoch {} test end {}, Accuracy={:.2f}'.format( + epoch, test_utils.now(), accuracy)) + max_accuracy = max(accuracy, max_accuracy) + test_utils.write_to_summary( + writer, + epoch, + dict_to_write={'Accuracy/test': accuracy}, + write_xla_metrics=True) if FLAGS.metrics_debug: xm.master_print(met.metrics_report())