@@ -54,13 +54,16 @@ status_t ref_softmax_fwd_t::execute_forward_dense(const exec_ctx_t &ctx) const {
54
54
const memory_desc_wrapper dst_d (pd ()->dst_md ());
55
55
56
56
const auto interim_dt = data_type::f32;
57
- const dim_t ou_stride = pd ()->outer_stride ();
58
57
const auto is_inplace = (src == dst);
59
58
const auto has_padding = is_padding (dst_d);
60
59
const auto zero_padding = has_padding && !is_inplace;
61
60
const auto axis = pd ()->axis ();
62
61
const auto axis_size = pd ()->axis_size (true );
63
- const auto axis_blk_size = src_d.padded_dims ()[axis] - src_d.dims ()[axis];
62
+ // Since dense implementation assumes `inner_size == 1`, and src is dense
63
+ // and identical to dst, outer_stride should coincide with axis_size. This
64
+ // allows to shuffle outer dimensions and not relying on a stride of a
65
+ // previous dimension.
66
+ const auto ou_stride = axis_size;
64
67
const auto src_dt_size = types::data_type_size (pd ()->src_md ()->data_type );
65
68
const auto dst_dt_size = types::data_type_size (pd ()->dst_md ()->data_type );
66
69
@@ -178,8 +181,9 @@ status_t ref_softmax_fwd_t::execute_forward_dense(const exec_ctx_t &ctx) const {
178
181
io::store_float_value (dst_d.data_type (), val, dst_data, c);
179
182
}
180
183
if (zero_padding) {
184
+ const auto tail = src_d.padded_dims ()[axis] - src_d.dims ()[axis];
181
185
PRAGMA_OMP_SIMD ()
182
- for (int i = 0 ; i < axis_blk_size ; i++)
186
+ for (int i = 0 ; i < tail ; i++)
183
187
io::store_float_value (
184
188
dst_d.data_type (), 0 , dst_data, channels_ + i);
185
189
}
@@ -315,7 +319,11 @@ status_t ref_softmax_bwd_t::execute_backward_dense(
315
319
const memory_desc_wrapper diff_dst_d (pd ()->diff_dst_md ());
316
320
const memory_desc_wrapper diff_src_d (pd ()->diff_src_md ());
317
321
318
- const auto ou_stride = pd ()->outer_stride ();
322
+ // Since dense implementation assumes `inner_size == 1`, and src is dense
323
+ // and identical to dst, outer_stride should coincide with axis_size. This
324
+ // allows to shuffle outer dimensions and not relying on a stride of a
325
+ // previous dimension.
326
+ const auto ou_stride = pd ()->axis_size ();
319
327
320
328
parallel_nd (outer_size_, [&](dim_t ou) {
321
329
float sbr = 0 ;
0 commit comments