@@ -100,7 +100,8 @@ struct ref_convolution_fwd_t : public gpu::generic::sycl::primitive_t {
100
100
&& IMPLICATION (!attr ()->scales_ .has_default_values (),
101
101
attr_scales_ok ()
102
102
&& check_convolution_scales_types (attr ()))
103
- && sycl_post_ops_t::post_ops_ok (attr (), false );
103
+ && sycl_post_ops_t::post_ops_ok (attr (), false )
104
+ && set_default_alg_kind (alg_kind::convolution_direct);
104
105
if (!ok) return status::unimplemented;
105
106
106
107
return init_conf ();
@@ -156,7 +157,8 @@ struct ref_convolution_bwd_data_t : public gpu::generic::sycl::primitive_t {
156
157
| sm::zero_points_runtime | sm::sum_dt)
157
158
&& IMPLICATION (!attr ()->scales_ .has_default_values (),
158
159
attr_scales_ok ()
159
- && check_convolution_scales_types (attr ()));
160
+ && check_convolution_scales_types (attr ()))
161
+ && set_default_alg_kind (alg_kind::convolution_direct);
160
162
if (!ok) return status::unimplemented;
161
163
162
164
return init_conf ();
@@ -195,7 +197,6 @@ struct ref_convolution_bwd_weights_t : public gpu::generic::sycl::primitive_t {
195
197
196
198
status_t init (impl::engine_t *engine) {
197
199
using namespace data_type ;
198
- using sm = primitive_attr_t ::skip_mask_t ;
199
200
200
201
const memory_desc_wrapper data_d (src_md ());
201
202
const memory_desc_wrapper diff_weights_d (diff_weights_md ());
@@ -208,11 +209,8 @@ struct ref_convolution_bwd_weights_t : public gpu::generic::sycl::primitive_t {
208
209
data_d, diff_weights_d, diff_dst_d)
209
210
&& check_convolution_formats (
210
211
data_d, diff_weights_d, diff_dst_d)
211
- && attr ()->has_default_values (sm::scales_runtime
212
- | sm::zero_points_runtime | sm::sum_dt)
213
- && IMPLICATION (!attr ()->scales_ .has_default_values (),
214
- attr_scales_ok ()
215
- && check_convolution_scales_types (attr ()));
212
+ && attr ()->has_default_values ()
213
+ && set_default_alg_kind (alg_kind::convolution_direct);
216
214
if (!ok) return status::unimplemented;
217
215
218
216
return init_conf ();
0 commit comments