Skip to content

Commit da44aea

Browse files
t4c1mgouicem
authored andcommitted
generic: sycl: softmax: bugfix checks
1 parent dbf9278 commit da44aea

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

src/gpu/generic/sycl/ref_softmax.hpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ struct ref_sycl_softmax_fwd_t : public gpu::generic::sycl::primitive_t {
4848
&& sycl_post_ops_t::post_ops_ok(attr(), true, false)
4949
&& set_default_formats() == status::success
5050
&& attr_.set_default_formats(dst_md()) == status::success
51-
&& check_formats(diff_src_md(), diff_dst_md())
51+
&& check_formats(src_md(), dst_md())
5252
&& md_dims_in_range(src_md());
5353

5454
if (!ok) return status::unimplemented;
@@ -111,7 +111,7 @@ struct ref_sycl_softmax_bwd_t : public gpu::generic::sycl::primitive_t {
111111
&& dst_md()->data_type == diff_dst_md()->data_type
112112
&& attr()->has_default_values()
113113
&& set_default_formats() == status::success
114-
&& check_formats(src_md(), dst_md())
114+
&& check_formats(diff_src_md(), diff_dst_md())
115115
&& md_dims_in_range(diff_dst_md());
116116

117117
if (!ok) return status::unimplemented;

0 commit comments

Comments
 (0)