@@ -33,7 +33,7 @@ namespace sycl {
33
33
struct convolution_kernel_fwd_t {
34
34
static constexpr int max_supported_ndims = 6 ;
35
35
36
- convolution_kernel_fwd_t (const sycl_convolution_conf_t &conf,
36
+ convolution_kernel_fwd_t (const sycl_convolution_fwd_conf_t &conf,
37
37
::sycl::handler &cgh, const exec_ctx_t &ctx)
38
38
: conf_(conf)
39
39
, data_(CTX_IN_SYCL_KERNEL_MEMORY(DNNL_ARG_SRC_0))
@@ -191,9 +191,8 @@ struct convolution_kernel_fwd_t {
191
191
accumulator *= sm_weights;
192
192
}
193
193
194
- if (bias_md ().ndims () != 0 ) {
195
- auto bias = load_float_value (
196
- bias_md ().data_type (), bias_ptr (), oc_tot);
194
+ if (conf_.has_bias ) {
195
+ auto bias = load_float_value (conf_.bias_dt , bias_ptr (), oc_tot);
197
196
accumulator += bias;
198
197
}
199
198
@@ -214,7 +213,6 @@ struct convolution_kernel_fwd_t {
214
213
private:
215
214
const xpu::sycl::md_t &data_md () const { return conf_.data_md ; }
216
215
const xpu::sycl::md_t &weights_md () const { return conf_.weights_md ; }
217
- const xpu::sycl::md_t &bias_md () const { return conf_.bias_md ; }
218
216
const xpu::sycl::md_t &dst_md () const { return conf_.dst_md ; }
219
217
220
218
void *data_ptr () const { return data_.get_pointer (); }
@@ -227,7 +225,7 @@ struct convolution_kernel_fwd_t {
227
225
void *data_zeropoint_ptr () const { return data_zeropoints_.get_pointer (); }
228
226
void *dst_zeropoint_ptr () const { return dst_zeropoints_.get_pointer (); }
229
227
230
- sycl_convolution_conf_t conf_;
228
+ sycl_convolution_fwd_conf_t conf_;
231
229
232
230
xpu::sycl::in_memory_arg_t data_;
233
231
xpu::sycl::in_memory_arg_t weights_;
@@ -247,7 +245,7 @@ struct convolution_kernel_fwd_t {
247
245
struct convolution_kernel_bwd_data_t {
248
246
static constexpr int max_supported_ndims = 6 ;
249
247
250
- convolution_kernel_bwd_data_t (const sycl_convolution_conf_t &conf,
248
+ convolution_kernel_bwd_data_t (const sycl_convolution_bwd_data_conf_t &conf,
251
249
::sycl::handler &cgh, const exec_ctx_t &ctx)
252
250
: conf_(conf)
253
251
, diff_data_(CTX_INOUT_SYCL_KERNEL_MEMORY(DNNL_ARG_DIFF_SRC))
@@ -423,9 +421,8 @@ struct convolution_kernel_bwd_data_t {
423
421
accumulator *= sm_weights;
424
422
}
425
423
426
- if (bias_md ().ndims () != 0 ) {
427
- auto bias = load_float_value (
428
- bias_md ().data_type (), bias_ptr (), ic_tot);
424
+ if (conf_.has_bias ) {
425
+ auto bias = load_float_value (conf_.bias_dt , bias_ptr (), ic_tot);
429
426
accumulator += bias;
430
427
}
431
428
@@ -446,7 +443,6 @@ struct convolution_kernel_bwd_data_t {
446
443
private:
447
444
const xpu::sycl::md_t &diff_data_md () const { return conf_.diff_data_md ; }
448
445
const xpu::sycl::md_t &weights_md () const { return conf_.weights_md ; }
449
- const xpu::sycl::md_t &bias_md () const { return conf_.bias_md ; }
450
446
const xpu::sycl::md_t &diff_dst_md () const { return conf_.diff_dst_md ; }
451
447
452
448
void *diff_data_ptr () const { return diff_data_.get_pointer (); }
@@ -459,7 +455,7 @@ struct convolution_kernel_bwd_data_t {
459
455
void *data_zeropoint_ptr () const { return data_zeropoints_.get_pointer (); }
460
456
void *dst_zeropoint_ptr () const { return dst_zeropoints_.get_pointer (); }
461
457
462
- sycl_convolution_conf_t conf_;
458
+ sycl_convolution_bwd_data_conf_t conf_;
463
459
464
460
xpu::sycl::inout_memory_arg_t diff_data_;
465
461
xpu::sycl::in_memory_arg_t weights_;
@@ -479,7 +475,8 @@ struct convolution_kernel_bwd_data_t {
479
475
struct convolution_kernel_bwd_weights_t {
480
476
static constexpr int max_supported_ndims = 6 ;
481
477
482
- convolution_kernel_bwd_weights_t (const sycl_convolution_conf_t &conf,
478
+ convolution_kernel_bwd_weights_t (
479
+ const sycl_convolution_bwd_weights_conf_t &conf,
483
480
::sycl::handler &cgh, const exec_ctx_t &ctx, int data_arg,
484
481
int diff_dst_arg)
485
482
: conf_(conf)
@@ -572,8 +569,8 @@ struct convolution_kernel_bwd_weights_t {
572
569
}
573
570
}
574
571
}
575
- store_float_value (diff_bias_md (). data_type () ,
576
- accumulator_bias, diff_bias_ptr (), g * OC + oc);
572
+ store_float_value (conf_. bias_dt , accumulator_bias ,
573
+ diff_bias_ptr (), g * OC + oc);
577
574
}
578
575
};
579
576
if (conf_.is_deconvolution ) {
@@ -624,15 +621,14 @@ struct convolution_kernel_bwd_weights_t {
624
621
const xpu::sycl::md_t &diff_weights_md () const {
625
622
return conf_.diff_weights_md ;
626
623
}
627
- const xpu::sycl::md_t &diff_bias_md () const { return conf_.diff_bias_md ; }
628
624
const xpu::sycl::md_t &diff_dst_md () const { return conf_.diff_dst_md ; }
629
625
630
626
void *data_ptr () const { return data_.get_pointer (); }
631
627
void *diff_weights_ptr () const { return diff_weights_.get_pointer (); }
632
628
void *diff_bias_ptr () const { return diff_bias_.get_pointer (); }
633
629
void *diff_dst_ptr () const { return diff_dst_.get_pointer (); }
634
630
635
- sycl_convolution_conf_t conf_;
631
+ sycl_convolution_bwd_weights_conf_t conf_;
636
632
637
633
xpu::sycl::in_memory_arg_t data_;
638
634
xpu::sycl::out_memory_arg_t diff_weights_;
0 commit comments