Skip to content

Commit d446428

Browse files
committed
src, tests: softmax: fix incorrect outer_stride variable
Previous logic didn't account for situations when neighbor dimension could have larger stride than it should be, like for acbd format.
1 parent b1d67fb commit d446428

File tree

6 files changed

+19
-9
lines changed

6 files changed

+19
-9
lines changed

src/common/softmax_pd.hpp

-4
Original file line numberDiff line numberDiff line change
@@ -81,10 +81,6 @@ struct softmax_pd_t : public primitive_desc_t {
8181
dst_desc().dims + axis() + 1, ndims() - 1 - axis());
8282
}
8383

84-
dim_t outer_stride() const {
85-
const memory_desc_wrapper dst_d(dst_desc());
86-
return axis() > 0 ? dst_d.blocking_desc().strides[axis() - 1] : 1;
87-
}
8884
dim_t axis_stride() const {
8985
const memory_desc_wrapper dst_d(dst_desc());
9086
return dst_d.blocking_desc().strides[axis()];

src/common/verbose.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -1411,7 +1411,7 @@ std::string init_info_softmax(const engine_t *e, const pd_t *pd) {
14111411

14121412
ss << "src_" << md2fmt_str(src_md, pd->invariant_src_user_format_kind());
14131413
ss << " dst_" << dst_md;
1414-
if (diff_dst_md) ss << " diff_dst_" << diff_dst_md;
1414+
if (!types::is_zero_md(diff_dst_md)) ss << " diff_dst_" << diff_dst_md;
14151415

14161416
ss << "," << pd->attr() << ",";
14171417
ss << "alg:" << pd->alg_kind() << " axis:" << pd->axis() << ",";

src/cpu/ref_softmax.cpp

+12-4
Original file line numberDiff line numberDiff line change
@@ -54,13 +54,16 @@ status_t ref_softmax_fwd_t::execute_forward_dense(const exec_ctx_t &ctx) const {
5454
const memory_desc_wrapper dst_d(pd()->dst_md());
5555

5656
const auto interim_dt = data_type::f32;
57-
const dim_t ou_stride = pd()->outer_stride();
5857
const auto is_inplace = (src == dst);
5958
const auto has_padding = is_padding(dst_d);
6059
const auto zero_padding = has_padding && !is_inplace;
6160
const auto axis = pd()->axis();
6261
const auto axis_size = pd()->axis_size(true);
63-
const auto axis_blk_size = src_d.padded_dims()[axis] - src_d.dims()[axis];
62+
// Since dense implementation assumes `inner_size == 1`, and src is dense
63+
// and identical to dst, outer_stride should coincide with axis_size. This
64+
// allows to shuffle outer dimensions and not relying on a stride of a
65+
// previous dimension.
66+
const auto ou_stride = axis_size;
6467
const auto src_dt_size = types::data_type_size(pd()->src_md()->data_type);
6568
const auto dst_dt_size = types::data_type_size(pd()->dst_md()->data_type);
6669

@@ -178,8 +181,9 @@ status_t ref_softmax_fwd_t::execute_forward_dense(const exec_ctx_t &ctx) const {
178181
io::store_float_value(dst_d.data_type(), val, dst_data, c);
179182
}
180183
if (zero_padding) {
184+
const auto tail = src_d.padded_dims()[axis] - src_d.dims()[axis];
181185
PRAGMA_OMP_SIMD()
182-
for (int i = 0; i < axis_blk_size; i++)
186+
for (int i = 0; i < tail; i++)
183187
io::store_float_value(
184188
dst_d.data_type(), 0, dst_data, channels_ + i);
185189
}
@@ -315,7 +319,11 @@ status_t ref_softmax_bwd_t::execute_backward_dense(
315319
const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
316320
const memory_desc_wrapper diff_src_d(pd()->diff_src_md());
317321

318-
const auto ou_stride = pd()->outer_stride();
322+
// Since dense implementation assumes `inner_size == 1`, and src is dense
323+
// and identical to dst, outer_stride should coincide with axis_size. This
324+
// allows to shuffle outer dimensions and not relying on a stride of a
325+
// previous dimension.
326+
const auto ou_stride = pd()->axis_size();
319327

320328
parallel_nd(outer_size_, [&](dim_t ou) {
321329
float sbr = 0;

tests/benchdnn/inputs/softmax/test_softmax_all

+2
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,8 @@
5555
--axis=0
5656
--batch=shapes_large
5757

58+
--reset --stag=acbd --dtag=acbd --sdt=f32 --ddt=f32 --axis=3 1x16x384x384_n"neighbor_dim_to_axis_has_larger_stride"
59+
5860
--batch=test_softmax_bfloat16
5961

6062
--batch=test_softmax_float16

tests/benchdnn/inputs/softmax/test_softmax_bfloat16

+2
Original file line numberDiff line numberDiff line change
@@ -44,3 +44,5 @@
4444
--axis=0,1
4545
--batch=set_0d
4646
--batch=shapes_2d
47+
48+
--reset --stag=acbd --dtag=acbd --sdt=bf16 --ddt=bf16 --axis=3 1x16x384x384_n"neighbor_dim_to_axis_has_larger_stride"

tests/benchdnn/inputs/softmax/test_softmax_float16

+2
Original file line numberDiff line numberDiff line change
@@ -39,3 +39,5 @@
3939
--axis=0,1
4040
--batch=set_0d
4141
--batch=shapes_2d
42+
43+
--reset --stag=acbd --dtag=acbd --sdt=f16 --ddt=f16 --axis=3 1x16x384x384_n"neighbor_dim_to_axis_has_larger_stride"

0 commit comments

Comments
 (0)