@@ -48,6 +48,7 @@ struct ref_sycl_softmax_fwd_t : public gpu::generic::sycl::primitive_t {
48
48
&& sycl_post_ops_t::post_ops_ok (attr (), true , false )
49
49
&& set_default_formats () == status::success
50
50
&& attr_.set_default_formats (dst_md ()) == status::success
51
+ && check_formats (src_md (), dst_md ())
51
52
&& md_dims_in_range (src_md ());
52
53
53
54
if (!ok) return status::unimplemented;
@@ -70,6 +71,15 @@ struct ref_sycl_softmax_fwd_t : public gpu::generic::sycl::primitive_t {
70
71
return utils::one_of (src, data_type::f32, data_type::bf16,
71
72
data_type::f16, data_type::s8, data_type::u8);
72
73
}
74
+
75
+ static bool check_formats (const memory_desc_wrapper &src,
76
+ const memory_desc_wrapper &dst) {
77
+ for (const auto &mdw : {src, dst}) {
78
+ if (!mdw.is_plain ()) return false ;
79
+ }
80
+
81
+ return true ;
82
+ }
73
83
};
74
84
75
85
status_t init (impl::engine_t *engine) override ;
@@ -101,12 +111,22 @@ struct ref_sycl_softmax_bwd_t : public gpu::generic::sycl::primitive_t {
101
111
&& dst_md ()->data_type == diff_dst_md ()->data_type
102
112
&& attr ()->has_default_values ()
103
113
&& set_default_formats () == status::success
114
+ && check_formats (src_md (), dst_md ())
104
115
&& md_dims_in_range (diff_dst_md ());
105
116
106
117
if (!ok) return status::unimplemented;
107
118
return init_conf ();
108
119
}
109
120
121
+ static bool check_formats (const memory_desc_wrapper &src,
122
+ const memory_desc_wrapper &dst) {
123
+ for (const auto &mdw : {src, dst}) {
124
+ if (!mdw.is_plain ()) return false ;
125
+ }
126
+
127
+ return true ;
128
+ }
129
+
110
130
sycl_softmax_conf_t conf_;
111
131
status_t init_conf ();
112
132
};
0 commit comments