Skip to content

Commit a6f4642

Browse files
committed
generic: gpu: convolution/deconvolution/softmax: add missing checks
1 parent 8393e5c commit a6f4642

File tree

3 files changed

+28
-9
lines changed

3 files changed

+28
-9
lines changed

src/gpu/generic/sycl/ref_convolution.hpp

+6-6
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,8 @@ struct ref_convolution_fwd_t : public gpu::generic::sycl::primitive_t {
9393
| sm::sum_dt)
9494
&& IMPLICATION(!attr()->scales_.has_default_values(),
9595
attr_scales_ok())
96-
&& sycl_post_ops_t::post_ops_ok(attr(), false);
96+
&& sycl_post_ops_t::post_ops_ok(attr(), false)
97+
&& set_default_alg_kind(alg_kind::convolution_direct);
9798
if (!ok) return status::unimplemented;
9899

99100
return init_conf();
@@ -148,7 +149,8 @@ struct ref_convolution_bwd_data_t : public gpu::generic::sycl::primitive_t {
148149
&& attr()->has_default_values(sm::scales_runtime
149150
| sm::zero_points_runtime | sm::sum_dt)
150151
&& IMPLICATION(!attr()->scales_.has_default_values(),
151-
attr_scales_ok());
152+
attr_scales_ok())
153+
&& set_default_alg_kind(alg_kind::convolution_direct);
152154
if (!ok) return status::unimplemented;
153155

154156
return init_conf();
@@ -200,10 +202,8 @@ struct ref_convolution_bwd_weights_t : public gpu::generic::sycl::primitive_t {
200202
data_d, diff_weights_d, diff_dst_d)
201203
&& check_convolution_formats(
202204
data_d, diff_weights_d, diff_dst_d)
203-
&& attr()->has_default_values(sm::scales_runtime
204-
| sm::zero_points_runtime | sm::sum_dt)
205-
&& IMPLICATION(!attr()->scales_.has_default_values(),
206-
attr_scales_ok());
205+
&& attr()->has_default_values()
206+
&& set_default_alg_kind(alg_kind::convolution_direct);
207207
if (!ok) return status::unimplemented;
208208

209209
return init_conf();

src/gpu/generic/sycl/ref_deconvolution.hpp

+2-3
Original file line numberDiff line numberDiff line change
@@ -57,9 +57,8 @@ struct ref_deconvolution_bwd_weights_t
5757
data_d, diff_weights_d, diff_dst_d)
5858
&& check_convolution_formats(
5959
data_d, diff_weights_d, diff_dst_d)
60-
&& attr()->has_default_values(sm::scales_runtime
61-
| sm::zero_points_runtime | sm::post_ops
62-
| sm::sum_dt);
60+
&& attr()->has_default_values()
61+
&& desc()->alg_kind == alg_kind::deconvolution_direct;
6362
if (!ok) return status::unimplemented;
6463

6564
return init_conf();

src/gpu/generic/sycl/ref_softmax.hpp

+20
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ struct ref_sycl_softmax_fwd_t : public gpu::generic::sycl::primitive_t {
4848
&& sycl_post_ops_t::post_ops_ok(attr(), true, false)
4949
&& set_default_formats() == status::success
5050
&& attr_.set_default_formats(dst_md()) == status::success
51+
&& check_formats(diff_src_md(), diff_dst_md())
5152
&& md_dims_in_range(src_md());
5253

5354
if (!ok) return status::unimplemented;
@@ -70,6 +71,15 @@ struct ref_sycl_softmax_fwd_t : public gpu::generic::sycl::primitive_t {
7071
return utils::one_of(src, data_type::f32, data_type::bf16,
7172
data_type::f16, data_type::s8, data_type::u8);
7273
}
74+
75+
static bool check_formats(const memory_desc_wrapper &src,
76+
const memory_desc_wrapper &dst) {
77+
for (const auto &mdw : {src, dst}) {
78+
if (!mdw.is_plain()) return false;
79+
}
80+
81+
return true;
82+
}
7383
};
7484

7585
status_t init(impl::engine_t *engine) override;
@@ -101,12 +111,22 @@ struct ref_sycl_softmax_bwd_t : public gpu::generic::sycl::primitive_t {
101111
&& dst_md()->data_type == diff_dst_md()->data_type
102112
&& attr()->has_default_values()
103113
&& set_default_formats() == status::success
114+
&& check_formats(src_md(), dst_md())
104115
&& md_dims_in_range(diff_dst_md());
105116

106117
if (!ok) return status::unimplemented;
107118
return init_conf();
108119
}
109120

121+
static bool check_formats(const memory_desc_wrapper &src,
122+
const memory_desc_wrapper &dst) {
123+
for (const auto &mdw : {src, dst}) {
124+
if (!mdw.is_plain()) return false;
125+
}
126+
127+
return true;
128+
}
129+
110130
sycl_softmax_conf_t conf_;
111131
status_t init_conf();
112132
};

0 commit comments

Comments
 (0)