Commit 38934f9

leo0519 authored and nv-kkudrynski committed
[RN50/Paddle] Remove export script and add INT8 feature (QAT + inference)
1 parent 9dd9fcb · commit 38934f9

13 files changed: +373 -263 lines changed

PaddlePaddle/Classification/RN50v1.5/Dockerfile

+1 -1

@@ -1,4 +1,4 @@
-ARG FROM_IMAGE_NAME=nvcr.io/nvidia/paddlepaddle:23.09-py3
+ARG FROM_IMAGE_NAME=nvcr.io/nvidia/paddlepaddle:23.12-py3
 FROM ${FROM_IMAGE_NAME}
 
 ADD requirements.txt /workspace/

PaddlePaddle/Classification/RN50v1.5/README.md

+235 -115
Large diffs are not rendered by default.

PaddlePaddle/Classification/RN50v1.5/export_model.py

-75
This file was deleted.

PaddlePaddle/Classification/RN50v1.5/inference.py

+16 -11

@@ -29,7 +29,7 @@
 
 
 def init_predictor(args):
-    infer_dir = args.trt_inference_dir
+    infer_dir = args.inference_dir
     assert os.path.isdir(
         infer_dir), f'inference_dir = "{infer_dir}" is not a directory'
     pdiparams_path = glob.glob(os.path.join(infer_dir, '*.pdiparams'))
@@ -41,7 +41,7 @@ def init_predictor(args):
     predictor_config = Config(pdmodel_path[0], pdiparams_path[0])
     predictor_config.enable_memory_optim()
     predictor_config.enable_use_gpu(0, args.device)
-    precision = args.trt_precision
+    precision = args.precision
     max_batch_size = args.batch_size
     assert precision in ['FP32', 'FP16', 'INT8'], \
         'precision should be FP32/FP16/INT8'
@@ -54,12 +54,17 @@ def init_predictor(args):
     else:
         raise NotImplementedError
     predictor_config.enable_tensorrt_engine(
-        workspace_size=args.trt_workspace_size,
+        workspace_size=args.workspace_size,
         max_batch_size=max_batch_size,
-        min_subgraph_size=args.trt_min_subgraph_size,
+        min_subgraph_size=args.min_subgraph_size,
         precision_mode=precision_mode,
-        use_static=args.trt_use_static,
-        use_calib_mode=args.trt_use_calib_mode)
+        use_static=args.use_static,
+        use_calib_mode=args.use_calib_mode)
+    predictor_config.set_trt_dynamic_shape_info(
+        {"data": (1,) + tuple(args.image_shape)},
+        {"data": (args.batch_size,) + tuple(args.image_shape)},
+        {"data": (args.batch_size,) + tuple(args.image_shape)},
+    )
     predictor = create_predictor(predictor_config)
     return predictor
 
@@ -140,7 +145,7 @@ def benchmark_dataset(args):
     quantile = np.quantile(latency, [0.9, 0.95, 0.99])
 
     statistics = {
-        'precision': args.trt_precision,
+        'precision': args.precision,
         'batch_size': batch_size,
         'throughput': total_images / (end - start),
         'accuracy': correct_predict / total_images,
@@ -189,7 +194,7 @@ def benchmark_synthetic(args):
     quantile = np.quantile(latency, [0.9, 0.95, 0.99])
 
     statistics = {
-        'precision': args.trt_precision,
+        'precision': args.precision,
         'batch_size': batch_size,
         'throughput': args.benchmark_steps * batch_size / (end - start),
         'eval_latency_avg': np.mean(latency),
@@ -200,11 +205,11 @@ def benchmark_synthetic(args):
     return statistics
 
 def main(args):
-    setup_dllogger(args.trt_log_path)
+    setup_dllogger(args.report_file)
     if args.show_config:
         print_args(args)
 
-    if args.trt_use_synthetic:
+    if args.use_synthetic:
         statistics = benchmark_synthetic(args)
     else:
         statistics = benchmark_dataset(args)
@@ -213,4 +218,4 @@ def main(args):
 
 
 if __name__ == '__main__':
-    main(parse_args(including_trt=True))
+    main(parse_args(script='inference'))
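
The new predictor options are easiest to see in isolation. Below is a minimal sketch of the TensorRT INT8 setup that init_predictor now performs; it uses only the paddle.inference calls visible in this diff, while the model paths, batch size, and 3x224x224 input shape are illustrative placeholders (the real script reads them from args).

# Sketch of the INT8 + dynamic-shape predictor setup from init_predictor above.
# Paths, batch size, and the input shape are placeholders, not the repo's defaults.
from paddle.inference import Config, PrecisionType, create_predictor

config = Config("./inference_qat/model.pdmodel",
                "./inference_qat/model.pdiparams")
config.enable_memory_optim()
config.enable_use_gpu(0, 0)                  # initial GPU memory pool (MB), device id
config.enable_tensorrt_engine(
    workspace_size=1 << 30,
    max_batch_size=256,
    min_subgraph_size=3,
    precision_mode=PrecisionType.Int8,       # INT8 engine; scales come from the QAT graph
    use_static=False,
    use_calib_mode=False)                    # QAT model needs no post-training calibration
# Dynamic-shape ranges for the "data" input: minimum, maximum, and optimal shapes.
config.set_trt_dynamic_shape_info(
    {"data": [1, 3, 224, 224]},
    {"data": [256, 3, 224, 224]},
    {"data": [256, 3, 224, 224]})
predictor = create_predictor(config)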

PaddlePaddle/Classification/RN50v1.5/program.py

+1
@@ -188,6 +188,7 @@ def dist_optimizer(args, optimizer):
     }
 
     dist_strategy.asp = args.asp
+    dist_strategy.qat = args.qat
 
     optimizer = fleet.distributed_optimizer(optimizer, strategy=dist_strategy)
 
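The single added line above is how QAT is switched on for distributed training: it rides on the same fleet DistributedStrategy object that already carries the ASP flag. A condensed sketch of that wiring follows (the repo's real dist_optimizer sets additional strategy options, and args stands for its parsed command-line flags):

# Condensed sketch of how the qat flag reaches the fleet optimizer; the real
# dist_optimizer in program.py configures more strategy options than shown here.
from paddle.distributed import fleet

def dist_optimizer_sketch(args, optimizer):
    dist_strategy = fleet.DistributedStrategy()
    dist_strategy.asp = args.asp   # existing flag: automatic (2:4) sparsity
    dist_strategy.qat = args.qat   # new flag: quantization-aware training
    return fleet.distributed_optimizer(optimizer, strategy=dist_strategy)
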

PaddlePaddle/Classification/RN50v1.5/scripts/inference/infer_resnet50_AMP.sh

+3 -3

@@ -14,9 +14,9 @@
 
 python inference.py \
     --data-layout NHWC \
-    --trt-inference-dir ./inference_amp \
-    --trt-precision FP16 \
+    --inference-dir ./inference_amp \
+    --precision FP16 \
     --batch-size 256 \
     --benchmark-steps 1024 \
     --benchmark-warmup-steps 16 \
-    --trt-use-synthetic True
+    --use-synthetic True

PaddlePaddle/Classification/RN50v1.5/scripts/inference/export_resnet50_TF32.sh renamed to PaddlePaddle/Classification/RN50v1.5/scripts/inference/infer_resnet50_QAT.sh

+8 -7

@@ -12,10 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-CKPT=${1:-"./output/ResNet50/89"}
-MODEL_PREFIX=${2:-"resnet_50_paddle"}
-
-python -m paddle.distributed.launch --gpus=0 export_model.py \
-    --trt-inference-dir ./inference_tf32 \
-    --from-checkpoint $CKPT \
-    --model-prefix ${MODEL_PREFIX}
+python inference.py \
+    --data-layout NHWC \
+    --inference-dir ./inference_qat \
+    --precision INT8 \
+    --batch-size 256 \
+    --benchmark-steps 1024 \
+    --benchmark-warmup-steps 16 \
+    --use-synthetic True

PaddlePaddle/Classification/RN50v1.5/scripts/inference/infer_resnet50_TF32.sh

+3 -3

@@ -13,10 +13,10 @@
 # limitations under the License.
 
 python inference.py \
-    --trt-inference-dir ./inference_tf32 \
-    --trt-precision FP32 \
+    --inference-dir ./inference_tf32 \
+    --precision FP32 \
     --dali-num-threads 8 \
     --batch-size 256 \
     --benchmark-steps 1024 \
     --benchmark-warmup-steps 16 \
-    --trt-use-synthetic True
+    --use-synthetic True

PaddlePaddle/Classification/RN50v1.5/scripts/training/train_resnet50_AMP_90E_DGXA100.sh

+2 -1

@@ -18,4 +18,5 @@ python -m paddle.distributed.launch --gpus=0,1,2,3,4,5,6,7 train.py \
     --scale-loss 128.0 \
     --use-dynamic-loss-scaling \
     --data-layout NHWC \
-    --fuse-resunit
+    --fuse-resunit \
+    --inference-dir ./inference_amp

PaddlePaddle/Classification/RN50v1.5/scripts/inference/export_resnet50_AMP.sh renamed to PaddlePaddle/Classification/RN50v1.5/scripts/training/train_resnet50_AMP_QAT_10E_DGXA100.sh

+11 -6

@@ -15,9 +15,14 @@
 CKPT=${1:-"./output/ResNet50/89"}
 MODEL_PREFIX=${2:-"resnet_50_paddle"}
 
-python -m paddle.distributed.launch --gpus=0 export_model.py \
-    --amp \
-    --data-layout NHWC \
-    --trt-inference-dir ./inference_amp \
-    --from-checkpoint ${CKPT} \
-    --model-prefix ${MODEL_PREFIX}
+python -m paddle.distributed.launch --gpus=0,1,2,3,4,5,6,7 train.py \
+    --from-pretrained-params ${CKPT} \
+    --model-prefix ${MODEL_PREFIX} \
+    --epochs 10 \
+    --amp \
+    --scale-loss 128.0 \
+    --use-dynamic-loss-scaling \
+    --data-layout NHWC \
+    --qat \
+    --lr 0.00005 \
+    --inference-dir ./inference_qat

PaddlePaddle/Classification/RN50v1.5/scripts/training/train_resnet50_TF32_90E_DGXA100.sh

+1 -1

@@ -12,4 +12,4 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-python -m paddle.distributed.launch --gpus=0,1,2,3,4,5,6,7 train.py --epochs 90
+python -m paddle.distributed.launch --gpus=0,1,2,3,4,5,6,7 train.py --epochs 90 --inference-dir ./inference_tf32

PaddlePaddle/Classification/RN50v1.5/train.py

+15 -2

@@ -28,6 +28,7 @@
 from paddle.static.amp.fp16_lists import AutoMixedPrecisionLists
 from paddle.static.amp.fp16_utils import cast_model_to_fp16
 from paddle.incubate import asp as sparsity
+from paddle.static.quantization.quanter import quant_aware
 
 
 class MetricSummary:
@@ -107,7 +108,7 @@ def main(args):
         eval_step_each_epoch = len(eval_dataloader)
         eval_prog = paddle.static.Program()
 
-        eval_fetchs, _, _, _ = program.build(
+        eval_fetchs, _, eval_feeds, _ = program.build(
             args,
             eval_prog,
             startup_prog,
@@ -147,6 +148,14 @@ def main(args):
         sparsity.prune_model(train_prog, mask_algo=args.mask_algo)
         logging.info("Pruning model done.")
 
+    if args.qat:
+        if args.run_scope == RunScope.EVAL_ONLY:
+            eval_prog = quant_aware(eval_prog, device, for_test=True, return_program=True)
+        else:
+            optimizer.qat_init(
+                device,
+                test_program=eval_prog)
+
     if eval_prog is not None:
         eval_prog = program.compile_prog(args, eval_prog, is_train=False)
 
@@ -169,7 +178,7 @@ def main(args):
 
         # Save a checkpoint
         if epoch_id % args.save_interval == 0:
-            model_path = os.path.join(args.output_dir, args.model_arch_name)
+            model_path = os.path.join(args.checkpoint_dir, args.model_arch_name)
             save_model(train_prog, model_path, epoch_id, args.model_prefix)
 
         # Evaluation
@@ -190,6 +199,10 @@ def main(args):
     if eval_summary.is_updated:
         program.log_info((), eval_summary.metric_dict, Mode.EVAL)
 
+    if eval_prog is not None:
+        model_path = os.path.join(args.inference_dir, args.model_arch_name)
+        paddle.static.save_inference_model(model_path, [eval_feeds['data']], [eval_fetchs['label'][0]], exe, program=eval_prog)
+
 
 
 if __name__ == '__main__':
     paddle.enable_static()
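
With export_model.py deleted, train.py now writes the deployable model itself through paddle.static.save_inference_model, as the last hunk shows. A small, self-contained illustration of that call follows; the toy network and output prefix are stand-ins for the repo's ResNet50 eval program and --inference-dir.

# Self-contained illustration of paddle.static.save_inference_model, the call
# train.py now uses in place of the deleted export_model.py. The tiny fc
# network and the output prefix are stand-ins for the real ResNet50 eval program.
import paddle

paddle.enable_static()
startup_prog = paddle.static.Program()
main_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
    data = paddle.static.data(name='data', shape=[None, 3, 224, 224], dtype='float32')
    logits = paddle.static.nn.fc(data, size=1000)   # flattens trailing dims by default

exe = paddle.static.Executor(paddle.CPUPlace())
exe.run(startup_prog)

# Writes ./resnet50_demo.pdmodel and ./resnet50_demo.pdiparams, the two files
# that inference.py later discovers via its *.pdmodel / *.pdiparams glob.
paddle.static.save_inference_model(
    './resnet50_demo', [data], [logits], exe, program=main_prog)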
