@@ -191,26 +191,27 @@ __kernel void ref_convolution_bwd_data(__global SRC_DATA_T *diff_src,
191
191
}
192
192
}
193
193
194
- ACC_DATA_T sum_src ;
195
- #if WITH_SUM
196
- sum_src = TO_ACC (SRC_TO_REF (diff_src [SRC_OFF (n , g * IC + ic , id , ih , iw )]));
197
- #endif
198
-
199
- ACC_DATA_T accumulator = TO_ACC (d );
194
+ POST_OP_DATA_T tmp = d ;
200
195
201
196
#if WITH_SRC_SCALES
202
- accumulator *= src_scales [0 ];
197
+ tmp *= src_scales [0 ];
203
198
#endif
204
199
#if WITH_WEI_SCALES
205
200
#if WEI_SCALES_MASK == 0
206
- accumulator *= wei_scales [0 ];
201
+ tmp *= wei_scales [0 ];
207
202
#else
208
- accumulator *= wei_scales [g * IC + ic ];
203
+ tmp *= wei_scales [g * IC + ic ];
209
204
#endif
210
205
#endif
211
206
212
207
#if WITH_BIAS
213
- accumulator += BIA_TO_REF (bias [g * IC + ic ]);
208
+ tmp += (POST_OP_DATA_T )BIA_TO_REF (bias [g * IC + ic ]);
209
+ #endif
210
+
211
+ POST_OP_DATA_T sum_src ;
212
+ #if WITH_SUM
213
+ sum_src = (POST_OP_DATA_T )SUM_TO_REF (
214
+ AS_SUM_DATA_T (diff_src [SRC_OFF (n , g * IC + ic , id , ih , iw )]));
214
215
#endif
215
216
216
217
#if NDIMS == 3
@@ -230,19 +231,19 @@ __kernel void ref_convolution_bwd_data(__global SRC_DATA_T *diff_src,
230
231
const unsigned po_d3 = 0 ;
231
232
const unsigned po_d4 = 0 ;
232
233
#endif
233
- APPLY_POST_OPS_SERIAL (accumulator , ACC_DATA_T , sum_src , float , n , 1 ,
234
- g * IC + ic , 1 , po_d2 , 1 , po_d3 , 1 , po_d4 , 1 , 0 , 1 );
234
+ APPLY_POST_OPS_SERIAL (tmp , POST_OP_DATA_T , sum_src , POST_OP_DATA_T , n , 1 ,
235
+ g * IC + ic , 1 , po_d2 , 1 , po_d3 , 1 , po_d4 , 1 , 0 , 1 );
235
236
236
237
#if WITH_DST_SCALES
237
- accumulator /= dst_scales [0 ];
238
+ tmp /= dst_scales [0 ];
238
239
#endif
239
240
240
241
#if WITH_DST_ZPOINTS
241
242
const int dst_zp = dst_zpoints [WITH_DST_ZPOINTS_PER_OC ? g * IC + ic : 0 ];
242
- accumulator += dst_zp ;
243
+ tmp += dst_zp ;
243
244
#endif // WITH_DST_ZPOINTS
244
245
245
- diff_src [SRC_OFF (n , g * IC + ic , id , ih , iw )] = TO_SRC (accumulator );
246
+ diff_src [SRC_OFF (n , g * IC + ic , id , ih , iw )] = TO_SRC (tmp );
246
247
}
247
248
#endif
248
249
0 commit comments