Skip to content

Commit 0613385

Browse files
committedOct 10, 2024
xe: ref_conv: align BWD_D post-ops with FWD
1 parent 304c93f commit 0613385

File tree

1 file changed

+16
-15
lines changed

1 file changed

+16
-15
lines changed
 

‎src/gpu/intel/ocl/ref_convolution.cl

+16-15
Original file line numberDiff line numberDiff line change
@@ -191,26 +191,27 @@ __kernel void ref_convolution_bwd_data(__global SRC_DATA_T *diff_src,
191191
}
192192
}
193193

194-
ACC_DATA_T sum_src;
195-
#if WITH_SUM
196-
sum_src = TO_ACC(SRC_TO_REF(diff_src[SRC_OFF(n, g * IC + ic, id, ih, iw)]));
197-
#endif
198-
199-
ACC_DATA_T accumulator = TO_ACC(d);
194+
POST_OP_DATA_T tmp = d;
200195

201196
#if WITH_SRC_SCALES
202-
accumulator *= src_scales[0];
197+
tmp *= src_scales[0];
203198
#endif
204199
#if WITH_WEI_SCALES
205200
#if WEI_SCALES_MASK == 0
206-
accumulator *= wei_scales[0];
201+
tmp *= wei_scales[0];
207202
#else
208-
accumulator *= wei_scales[g * IC + ic];
203+
tmp *= wei_scales[g * IC + ic];
209204
#endif
210205
#endif
211206

212207
#if WITH_BIAS
213-
accumulator += BIA_TO_REF(bias[g * IC + ic]);
208+
tmp += (POST_OP_DATA_T)BIA_TO_REF(bias[g * IC + ic]);
209+
#endif
210+
211+
POST_OP_DATA_T sum_src;
212+
#if WITH_SUM
213+
sum_src = (POST_OP_DATA_T)SUM_TO_REF(
214+
AS_SUM_DATA_T(diff_src[SRC_OFF(n, g * IC + ic, id, ih, iw)]));
214215
#endif
215216

216217
#if NDIMS == 3
@@ -230,19 +231,19 @@ __kernel void ref_convolution_bwd_data(__global SRC_DATA_T *diff_src,
230231
const unsigned po_d3 = 0;
231232
const unsigned po_d4 = 0;
232233
#endif
233-
APPLY_POST_OPS_SERIAL(accumulator, ACC_DATA_T, sum_src, float, n, 1,
234-
g *IC + ic, 1, po_d2, 1, po_d3, 1, po_d4, 1, 0, 1);
234+
APPLY_POST_OPS_SERIAL(tmp, POST_OP_DATA_T, sum_src, POST_OP_DATA_T, n, 1,
235+
g * IC + ic, 1, po_d2, 1, po_d3, 1, po_d4, 1, 0, 1);
235236

236237
#if WITH_DST_SCALES
237-
accumulator /= dst_scales[0];
238+
tmp /= dst_scales[0];
238239
#endif
239240

240241
#if WITH_DST_ZPOINTS
241242
const int dst_zp = dst_zpoints[WITH_DST_ZPOINTS_PER_OC ? g * IC + ic : 0];
242-
accumulator += dst_zp;
243+
tmp += dst_zp;
243244
#endif // WITH_DST_ZPOINTS
244245

245-
diff_src[SRC_OFF(n, g * IC + ic, id, ih, iw)] = TO_SRC(accumulator);
246+
diff_src[SRC_OFF(n, g * IC + ic, id, ih, iw)] = TO_SRC(tmp);
246247
}
247248
#endif
248249

0 commit comments

Comments
 (0)