Skip to content

Commit 23914f0

Browse files
t4c1mgouicem
authored andcommitted
generic: gpu: convolution/deconvolution/softmax: add missing checks
1 parent 85c99cf commit 23914f0

File tree

3 files changed

+28
-12
lines changed

3 files changed

+28
-12
lines changed

src/gpu/generic/sycl/ref_convolution.hpp

+6-8
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,8 @@ struct ref_convolution_fwd_t : public gpu::generic::sycl::primitive_t {
100100
&& IMPLICATION(!attr()->scales_.has_default_values(),
101101
attr_scales_ok()
102102
&& check_convolution_scales_types(attr()))
103-
&& sycl_post_ops_t::post_ops_ok(attr(), false);
103+
&& sycl_post_ops_t::post_ops_ok(attr(), false)
104+
&& set_default_alg_kind(alg_kind::convolution_direct);
104105
if (!ok) return status::unimplemented;
105106

106107
return init_conf();
@@ -156,7 +157,8 @@ struct ref_convolution_bwd_data_t : public gpu::generic::sycl::primitive_t {
156157
| sm::zero_points_runtime | sm::sum_dt)
157158
&& IMPLICATION(!attr()->scales_.has_default_values(),
158159
attr_scales_ok()
159-
&& check_convolution_scales_types(attr()));
160+
&& check_convolution_scales_types(attr()))
161+
&& set_default_alg_kind(alg_kind::convolution_direct);
160162
if (!ok) return status::unimplemented;
161163

162164
return init_conf();
@@ -195,7 +197,6 @@ struct ref_convolution_bwd_weights_t : public gpu::generic::sycl::primitive_t {
195197

196198
status_t init(impl::engine_t *engine) {
197199
using namespace data_type;
198-
using sm = primitive_attr_t::skip_mask_t;
199200

200201
const memory_desc_wrapper data_d(src_md());
201202
const memory_desc_wrapper diff_weights_d(diff_weights_md());
@@ -208,11 +209,8 @@ struct ref_convolution_bwd_weights_t : public gpu::generic::sycl::primitive_t {
208209
data_d, diff_weights_d, diff_dst_d)
209210
&& check_convolution_formats(
210211
data_d, diff_weights_d, diff_dst_d)
211-
&& attr()->has_default_values(sm::scales_runtime
212-
| sm::zero_points_runtime | sm::sum_dt)
213-
&& IMPLICATION(!attr()->scales_.has_default_values(),
214-
attr_scales_ok()
215-
&& check_convolution_scales_types(attr()));
212+
&& attr()->has_default_values()
213+
&& set_default_alg_kind(alg_kind::convolution_direct);
216214
if (!ok) return status::unimplemented;
217215

218216
return init_conf();

src/gpu/generic/sycl/ref_deconvolution.hpp

+2-4
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,6 @@ struct ref_deconvolution_bwd_weights_t
4444

4545
status_t init(impl::engine_t *engine) {
4646
using namespace data_type;
47-
using sm = primitive_attr_t::skip_mask_t;
4847

4948
const memory_desc_wrapper data_d(src_md());
5049
const memory_desc_wrapper diff_weights_d(diff_weights_md());
@@ -57,9 +56,8 @@ struct ref_deconvolution_bwd_weights_t
5756
data_d, diff_weights_d, diff_dst_d)
5857
&& check_convolution_formats(
5958
data_d, diff_weights_d, diff_dst_d)
60-
&& attr()->has_default_values(sm::scales_runtime
61-
| sm::zero_points_runtime | sm::post_ops
62-
| sm::sum_dt);
59+
&& attr()->has_default_values()
60+
&& desc()->alg_kind == alg_kind::deconvolution_direct;
6361
if (!ok) return status::unimplemented;
6462

6563
return init_conf();

src/gpu/generic/sycl/ref_softmax.hpp

+20
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ struct ref_sycl_softmax_fwd_t : public gpu::generic::sycl::primitive_t {
4848
&& sycl_post_ops_t::post_ops_ok(attr(), true, false)
4949
&& set_default_formats() == status::success
5050
&& attr_.set_default_formats(dst_md()) == status::success
51+
&& check_formats(diff_src_md(), diff_dst_md())
5152
&& md_dims_in_range(src_md());
5253

5354
if (!ok) return status::unimplemented;
@@ -70,6 +71,15 @@ struct ref_sycl_softmax_fwd_t : public gpu::generic::sycl::primitive_t {
7071
return utils::one_of(src, data_type::f32, data_type::bf16,
7172
data_type::f16, data_type::s8, data_type::u8);
7273
}
74+
75+
static bool check_formats(const memory_desc_wrapper &src,
76+
const memory_desc_wrapper &dst) {
77+
for (const auto &mdw : {src, dst}) {
78+
if (!mdw.is_plain()) return false;
79+
}
80+
81+
return true;
82+
}
7383
};
7484

7585
status_t init(impl::engine_t *engine) override;
@@ -101,12 +111,22 @@ struct ref_sycl_softmax_bwd_t : public gpu::generic::sycl::primitive_t {
101111
&& dst_md()->data_type == diff_dst_md()->data_type
102112
&& attr()->has_default_values()
103113
&& set_default_formats() == status::success
114+
&& check_formats(src_md(), dst_md())
104115
&& md_dims_in_range(diff_dst_md());
105116

106117
if (!ok) return status::unimplemented;
107118
return init_conf();
108119
}
109120

121+
static bool check_formats(const memory_desc_wrapper &src,
122+
const memory_desc_wrapper &dst) {
123+
for (const auto &mdw : {src, dst}) {
124+
if (!mdw.is_plain()) return false;
125+
}
126+
127+
return true;
128+
}
129+
110130
sycl_softmax_conf_t conf_;
111131
status_t init_conf();
112132
};

0 commit comments

Comments
 (0)