@@ -511,14 +511,14 @@ struct cudnn_convolution_impl_fwd_t : public cudnn_convolution_impl_base_t {
         const float beta = 0.0f;
         if (flip_formats) {
             CUDNN_EXECUTE_FUNC_V(cudnnTransformTensor, handle, &alpha,
-                    reorder_dst_desc, src, &beta, descs[y], dst);
+                    reorder_dst_desc, src, &beta, y_fp32_desc, dst);
         } else {
-            CUDNN_EXECUTE_FUNC_V(cudnnTransformTensor, handle, &alpha, descs[y],
-                    src, &beta, reorder_dst_desc, dst);
+            CUDNN_EXECUTE_FUNC_V(cudnnTransformTensor, handle, &alpha,
+                    y_fp32_desc, src, &beta, reorder_dst_desc, dst);
         }
     }
 
-    void execute_f32_sum(cudnnHandle_t handle, void *y, void *y_fp32_data,
+    void execute_f32_dst_sum(cudnnHandle_t handle, void *y, void *y_fp32_data,
             float alpha_, float beta_) const {
         float alpha1 = 0.0f;
         float alpha2 = alpha_;
@@ -528,6 +528,14 @@ struct cudnn_convolution_impl_fwd_t : public cudnn_convolution_impl_base_t {
                 y_fp32_data);
     }
 
+    void execute_f32_src_sum(cudnnHandle_t handle, void *x, void *y,
+            float alpha_, float beta_) const {
+        float alpha = alpha_;
+        float beta = beta_;
+        CUDNN_EXECUTE_FUNC_V(cudnnAddTensor, handle, &alpha, descs[io::y], x,
+                &beta, y_fp32_desc, y);
+    }
+
     void execute_eltwise(cudnnHandle_t handle, void *src, void *dst) const {
         float alpha = 1.0f;
         float beta = 0.0f;
@@ -551,8 +559,7 @@ struct cudnn_convolution_impl_fwd_t : public cudnn_convolution_impl_base_t {
             const std::vector<void *> &args) const override {
         auto x = args[0], weights = args[1], y = args[2], bias = args[3],
              scratchpad = args[4], post_op_scratch = args[6],
-             post_op_reorder = args[7], src_scale = args[8],
-             wei_scale = args[9], dst_scale = args[10];
+             src_scale = args[7], wei_scale = args[8], dst_scale = args[9];
         void *output = use_temp_dst_ ? post_op_scratch : y;
         if (using_transformed_filter()) {
             auto w_scratch = args[5];
@@ -561,7 +568,7 @@ struct cudnn_convolution_impl_fwd_t : public cudnn_convolution_impl_base_t {
         }
 
         float *y_fp32_data = nullptr;
-        if (y_f32_is_required()) { y_fp32_data = (float *)args[11]; }
+        if (y_f32_is_required()) { y_fp32_data = (float *)args[10]; }
 
         bool fused = conv_bias || conv_bias_eltwise;
 
@@ -581,7 +588,8 @@ struct cudnn_convolution_impl_fwd_t : public cudnn_convolution_impl_base_t {
             }
         }
 
-        auto &y_desc = y_f32_is_required() ? y_fp32_desc : descs[io::y];
+        auto &y_desc = (y_f32_is_required() || use_temp_dst_) ? y_fp32_desc
+                                                              : descs[io::y];
         void *y_data = y_f32_is_required() ? y_fp32_data : output;
 
         if (fused) {
@@ -619,12 +627,11 @@ struct cudnn_convolution_impl_fwd_t : public cudnn_convolution_impl_base_t {
             switch (post_ops[i]) {
                 case dnnl_sum:
                     if (need_reorder) {
-                        execute_reorder(handle, y, post_op_reorder, true);
-                        execute_sum(handle, post_op_reorder, post_op_scratch,
-                                sum_scale, 1.0f);
+                        execute_f32_src_sum(
+                                handle, y, post_op_scratch, sum_scale, 1.0f);
                     } else if (last_op) {
                         if (y_f32_is_required()) {
-                            execute_f32_sum(
+                            execute_f32_dst_sum(
                                     handle, y, y_fp32_data, 1.0f, sum_scale);
                         } else {
                             execute_sum(handle, post_op_scratch, y, 1.0f,
@@ -687,7 +694,7 @@ struct cudnn_convolution_impl_fwd_t : public cudnn_convolution_impl_base_t {
         // The scratchpad size will need to be modified in
         // cases where the dst_scaling is used and the output
        // uses s8 values.
-        if (use_scales_dst_) {
+        if (use_scales_dst_ || use_temp_dst_) {
            CHECK(create_and_set_tensor_descriptor(&y_fp32_desc,
                    CUDNN_DATA_FLOAT, ndims[y], dims[y], strides[y]));
            CHECK(CUDNN_EXECUTE_FUNC_S(cudnnGetConvolutionForwardWorkspaceSize,
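Note on the new sum paths: both helpers lean on cuDNN's built-in blending. cudnnTransformTensor computes y = alpha * transform(x) + beta * y, and cudnnAddTensor computes C = alpha * A + beta * C, which is why execute_f32_src_sum can fold the old reorder-then-sum pair into a single call. The standalone sketch below is not part of the patch; the handle setup, tensor shape, and values are made up for illustration. It shows the cudnnAddTensor accumulation that the dnnl_sum post-op now relies on, with alpha playing the role of sum_scale and beta = 1.0f preserving the accumulated value.

// Minimal sketch, assuming cuDNN and the CUDA runtime are available;
// shapes and data are hypothetical.
#include <cudnn.h>
#include <cuda_runtime.h>
#include <cstdio>

int main() {
    cudnnHandle_t handle;
    cudnnCreate(&handle);

    // One NCHW f32 descriptor of shape 1x1x2x2, reused for both A and C.
    cudnnTensorDescriptor_t desc;
    cudnnCreateTensorDescriptor(&desc);
    cudnnSetTensor4dDescriptor(
            desc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, 1, 1, 2, 2);

    float host_a[4] = {1.f, 2.f, 3.f, 4.f}; // stands in for the dst tensor y
    float host_c[4] = {10.f, 20.f, 30.f, 40.f}; // stands in for post_op_scratch
    float *dev_a, *dev_c;
    cudaMalloc(&dev_a, sizeof(host_a));
    cudaMalloc(&dev_c, sizeof(host_c));
    cudaMemcpy(dev_a, host_a, sizeof(host_a), cudaMemcpyHostToDevice);
    cudaMemcpy(dev_c, host_c, sizeof(host_c), cudaMemcpyHostToDevice);

    // cudnnAddTensor computes C = alpha * A + beta * C in place on C,
    // the same accumulation execute_f32_src_sum performs.
    float alpha = 2.0f; // plays the role of sum_scale
    float beta = 1.0f;
    cudnnAddTensor(handle, &alpha, desc, dev_a, &beta, desc, dev_c);

    cudaMemcpy(host_c, dev_c, sizeof(host_c), cudaMemcpyDeviceToHost);
    for (float v : host_c)
        printf("%g ", v); // expected: 12 24 36 48

    cudaFree(dev_a);
    cudaFree(dev_c);
    cudnnDestroyTensorDescriptor(desc);
    cudnnDestroy(handle);
    return 0;
}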
0 commit comments