Skip to content

Commit 729ece5

Browse files
committed
generic: conv: deconv: reduce kernel argument size
1 parent 2bf5ffc commit 729ece5

6 files changed

+57
-38
lines changed

src/gpu/generic/sycl/convolution_kernels.hpp

+13-17
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ namespace sycl {
3333
struct convolution_kernel_fwd_t {
3434
static constexpr int max_supported_ndims = 6;
3535

36-
convolution_kernel_fwd_t(const sycl_convolution_conf_t &conf,
36+
convolution_kernel_fwd_t(const sycl_convolution_fwd_conf_t &conf,
3737
::sycl::handler &cgh, const exec_ctx_t &ctx)
3838
: conf_(conf)
3939
, data_(CTX_IN_SYCL_KERNEL_MEMORY(DNNL_ARG_SRC_0))
@@ -191,9 +191,8 @@ struct convolution_kernel_fwd_t {
191191
accumulator *= sm_weights;
192192
}
193193

194-
if (bias_md().ndims() != 0) {
195-
auto bias = load_float_value(
196-
bias_md().data_type(), bias_ptr(), oc_tot);
194+
if (conf_.has_bias) {
195+
auto bias = load_float_value(conf_.bias_dt, bias_ptr(), oc_tot);
197196
accumulator += bias;
198197
}
199198

@@ -214,7 +213,6 @@ struct convolution_kernel_fwd_t {
214213
private:
215214
const xpu::sycl::md_t &data_md() const { return conf_.data_md; }
216215
const xpu::sycl::md_t &weights_md() const { return conf_.weights_md; }
217-
const xpu::sycl::md_t &bias_md() const { return conf_.bias_md; }
218216
const xpu::sycl::md_t &dst_md() const { return conf_.dst_md; }
219217

220218
void *data_ptr() const { return data_.get_pointer(); }
@@ -227,7 +225,7 @@ struct convolution_kernel_fwd_t {
227225
void *data_zeropoint_ptr() const { return data_zeropoints_.get_pointer(); }
228226
void *dst_zeropoint_ptr() const { return dst_zeropoints_.get_pointer(); }
229227

230-
sycl_convolution_conf_t conf_;
228+
sycl_convolution_fwd_conf_t conf_;
231229

232230
xpu::sycl::in_memory_arg_t data_;
233231
xpu::sycl::in_memory_arg_t weights_;
@@ -247,7 +245,7 @@ struct convolution_kernel_fwd_t {
247245
struct convolution_kernel_bwd_data_t {
248246
static constexpr int max_supported_ndims = 6;
249247

250-
convolution_kernel_bwd_data_t(const sycl_convolution_conf_t &conf,
248+
convolution_kernel_bwd_data_t(const sycl_convolution_bwd_data_conf_t &conf,
251249
::sycl::handler &cgh, const exec_ctx_t &ctx)
252250
: conf_(conf)
253251
, diff_data_(CTX_INOUT_SYCL_KERNEL_MEMORY(DNNL_ARG_DIFF_SRC))
@@ -423,9 +421,8 @@ struct convolution_kernel_bwd_data_t {
423421
accumulator *= sm_weights;
424422
}
425423

426-
if (bias_md().ndims() != 0) {
427-
auto bias = load_float_value(
428-
bias_md().data_type(), bias_ptr(), ic_tot);
424+
if (conf_.has_bias) {
425+
auto bias = load_float_value(conf_.bias_dt, bias_ptr(), ic_tot);
429426
accumulator += bias;
430427
}
431428

@@ -446,7 +443,6 @@ struct convolution_kernel_bwd_data_t {
446443
private:
447444
const xpu::sycl::md_t &diff_data_md() const { return conf_.diff_data_md; }
448445
const xpu::sycl::md_t &weights_md() const { return conf_.weights_md; }
449-
const xpu::sycl::md_t &bias_md() const { return conf_.bias_md; }
450446
const xpu::sycl::md_t &diff_dst_md() const { return conf_.diff_dst_md; }
451447

452448
void *diff_data_ptr() const { return diff_data_.get_pointer(); }
@@ -459,7 +455,7 @@ struct convolution_kernel_bwd_data_t {
459455
void *data_zeropoint_ptr() const { return data_zeropoints_.get_pointer(); }
460456
void *dst_zeropoint_ptr() const { return dst_zeropoints_.get_pointer(); }
461457

462-
sycl_convolution_conf_t conf_;
458+
sycl_convolution_bwd_data_conf_t conf_;
463459

464460
xpu::sycl::inout_memory_arg_t diff_data_;
465461
xpu::sycl::in_memory_arg_t weights_;
@@ -479,7 +475,8 @@ struct convolution_kernel_bwd_data_t {
479475
struct convolution_kernel_bwd_weights_t {
480476
static constexpr int max_supported_ndims = 6;
481477

482-
convolution_kernel_bwd_weights_t(const sycl_convolution_conf_t &conf,
478+
convolution_kernel_bwd_weights_t(
479+
const sycl_convolution_bwd_weights_conf_t &conf,
483480
::sycl::handler &cgh, const exec_ctx_t &ctx, int data_arg,
484481
int diff_dst_arg)
485482
: conf_(conf)
@@ -572,8 +569,8 @@ struct convolution_kernel_bwd_weights_t {
572569
}
573570
}
574571
}
575-
store_float_value(diff_bias_md().data_type(),
576-
accumulator_bias, diff_bias_ptr(), g * OC + oc);
572+
store_float_value(conf_.bias_dt, accumulator_bias,
573+
diff_bias_ptr(), g * OC + oc);
577574
}
578575
};
579576
if (conf_.is_deconvolution) {
@@ -624,15 +621,14 @@ struct convolution_kernel_bwd_weights_t {
624621
const xpu::sycl::md_t &diff_weights_md() const {
625622
return conf_.diff_weights_md;
626623
}
627-
const xpu::sycl::md_t &diff_bias_md() const { return conf_.diff_bias_md; }
628624
const xpu::sycl::md_t &diff_dst_md() const { return conf_.diff_dst_md; }
629625

630626
void *data_ptr() const { return data_.get_pointer(); }
631627
void *diff_weights_ptr() const { return diff_weights_.get_pointer(); }
632628
void *diff_bias_ptr() const { return diff_bias_.get_pointer(); }
633629
void *diff_dst_ptr() const { return diff_dst_.get_pointer(); }
634630

635-
sycl_convolution_conf_t conf_;
631+
sycl_convolution_bwd_weights_conf_t conf_;
636632

637633
xpu::sycl::in_memory_arg_t data_;
638634
xpu::sycl::out_memory_arg_t diff_weights_;

src/gpu/generic/sycl/ref_convolution.cpp

+13-6
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,14 @@ namespace generic {
2525
namespace sycl {
2626

2727
status_t ref_convolution_fwd_t::pd_t::init_conf() {
28-
conf_ = sycl_convolution_conf_t();
28+
conf_ = sycl_convolution_fwd_conf_t();
2929

3030
conf_.data_md = xpu::sycl::md_t(src_md());
3131
conf_.weights_md = xpu::sycl::md_t(weights_md(0));
32-
if (with_bias()) { conf_.bias_md = xpu::sycl::md_t(weights_md(1)); }
32+
if (with_bias()) {
33+
conf_.bias_dt = weights_md(1)->data_type;
34+
conf_.has_bias = true;
35+
}
3336
conf_.dst_md = xpu::sycl::md_t(dst_md());
3437
conf_.ndims = ndims();
3538

@@ -85,11 +88,14 @@ status_t ref_convolution_fwd_t::execute(const exec_ctx_t &ctx) const {
8588
}
8689

8790
status_t ref_convolution_bwd_data_t::pd_t::init_conf() {
88-
conf_ = sycl_convolution_conf_t();
91+
conf_ = sycl_convolution_bwd_data_conf_t();
8992

9093
conf_.diff_data_md = xpu::sycl::md_t(diff_src_md());
9194
conf_.weights_md = xpu::sycl::md_t(weights_md(0));
92-
if (with_bias()) { conf_.bias_md = xpu::sycl::md_t(weights_md(1)); }
95+
if (with_bias()) {
96+
conf_.bias_dt = weights_md(1)->data_type;
97+
conf_.has_bias = true;
98+
}
9399
conf_.diff_dst_md = xpu::sycl::md_t(diff_dst_md());
94100
conf_.ndims = ndims();
95101

@@ -145,12 +151,13 @@ status_t ref_convolution_bwd_data_t::execute(const exec_ctx_t &ctx) const {
145151
}
146152

147153
status_t ref_convolution_bwd_weights_t::pd_t::init_conf() {
148-
conf_ = sycl_convolution_conf_t();
154+
conf_ = sycl_convolution_bwd_weights_conf_t();
149155

150156
conf_.data_md = xpu::sycl::md_t(src_md());
151157
conf_.diff_weights_md = xpu::sycl::md_t(diff_weights_md(0));
152158
if (with_bias()) {
153-
conf_.diff_bias_md = xpu::sycl::md_t(diff_weights_md(1));
159+
conf_.bias_dt = diff_weights_md(1)->data_type;
160+
conf_.has_bias = true;
154161
}
155162
conf_.diff_dst_md = xpu::sycl::md_t(diff_dst_md());
156163
conf_.ndims = ndims();

src/gpu/generic/sycl/ref_convolution.hpp

+3-3
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ struct ref_convolution_fwd_t : public gpu::generic::sycl::primitive_t {
107107
return init_conf();
108108
}
109109

110-
sycl_convolution_conf_t conf_;
110+
sycl_convolution_fwd_conf_t conf_;
111111

112112
private:
113113
status_t init_conf();
@@ -164,7 +164,7 @@ struct ref_convolution_bwd_data_t : public gpu::generic::sycl::primitive_t {
164164
return init_conf();
165165
}
166166

167-
sycl_convolution_conf_t conf_;
167+
sycl_convolution_bwd_data_conf_t conf_;
168168

169169
private:
170170
status_t init_conf();
@@ -216,7 +216,7 @@ struct ref_convolution_bwd_weights_t : public gpu::generic::sycl::primitive_t {
216216
return init_conf();
217217
}
218218

219-
sycl_convolution_conf_t conf_;
219+
sycl_convolution_bwd_weights_conf_t conf_;
220220

221221
private:
222222
status_t init_conf();

src/gpu/generic/sycl/ref_deconvolution.cpp

+3-2
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,12 @@ namespace generic {
2525
namespace sycl {
2626

2727
status_t ref_deconvolution_bwd_weights_t::pd_t::init_conf() {
28-
conf_ = sycl_convolution_conf_t();
28+
conf_ = sycl_convolution_bwd_weights_conf_t();
2929

3030
conf_.diff_dst_md = xpu::sycl::md_t(src_md());
3131
if (with_bias()) {
32-
conf_.diff_bias_md = xpu::sycl::md_t(diff_weights_md(1));
32+
conf_.bias_dt = diff_weights_md(1)->data_type;
33+
conf_.has_bias = true;
3334
}
3435
conf_.data_md = xpu::sycl::md_t(diff_dst_md());
3536
conf_.ndims = ndims();

src/gpu/generic/sycl/ref_deconvolution.hpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ struct ref_deconvolution_bwd_weights_t
6363
return init_conf();
6464
}
6565

66-
sycl_convolution_conf_t conf_;
66+
sycl_convolution_bwd_weights_conf_t conf_;
6767

6868
private:
6969
status_t init_conf();

src/gpu/generic/sycl/sycl_primitive_conf.hpp

+24-9
Original file line numberDiff line numberDiff line change
@@ -47,15 +47,9 @@ struct sycl_binary_conf_t {
4747
sycl_post_ops_t post_ops;
4848
};
4949

50-
struct sycl_convolution_conf_t {
51-
xpu::sycl::md_t data_md;
52-
xpu::sycl::md_t dst_md;
53-
xpu::sycl::md_t weights_md;
54-
xpu::sycl::md_t bias_md;
55-
xpu::sycl::md_t diff_data_md;
56-
xpu::sycl::md_t diff_dst_md;
57-
xpu::sycl::md_t diff_weights_md;
58-
xpu::sycl::md_t diff_bias_md;
50+
struct sycl_convolution_common_conf_t {
51+
bool has_bias = false;
52+
data_type_t bias_dt;
5953

6054
int padding[3];
6155
int strides[3];
@@ -81,6 +75,24 @@ struct sycl_convolution_conf_t {
8175
sycl_post_ops_t post_ops;
8276
};
8377

78+
struct sycl_convolution_fwd_conf_t : sycl_convolution_common_conf_t {
79+
xpu::sycl::md_t data_md;
80+
xpu::sycl::md_t dst_md;
81+
xpu::sycl::md_t weights_md;
82+
};
83+
84+
struct sycl_convolution_bwd_data_conf_t : sycl_convolution_common_conf_t {
85+
xpu::sycl::md_t weights_md;
86+
xpu::sycl::md_t diff_data_md;
87+
xpu::sycl::md_t diff_dst_md;
88+
};
89+
90+
struct sycl_convolution_bwd_weights_conf_t : sycl_convolution_common_conf_t {
91+
xpu::sycl::md_t data_md;
92+
xpu::sycl::md_t diff_dst_md;
93+
xpu::sycl::md_t diff_weights_md;
94+
};
95+
8496
struct sycl_eltwise_conf_t {
8597
prop_kind_t prop_kind;
8698
xpu::sycl::md_t src_md;
@@ -416,6 +428,9 @@ CHECK_SYCL_KERNEL_ARG_TYPE(sycl_sum_conf_t);
416428
CHECK_SYCL_KERNEL_ARG_TYPE(sycl_pooling_base_conf_t);
417429
CHECK_SYCL_KERNEL_ARG_TYPE(sycl_pooling_fwd_conf_t);
418430
CHECK_SYCL_KERNEL_ARG_TYPE(sycl_pooling_bwd_conf_t);
431+
CHECK_SYCL_KERNEL_ARG_TYPE(sycl_convolution_fwd_conf_t);
432+
CHECK_SYCL_KERNEL_ARG_TYPE(sycl_convolution_bwd_data_conf_t);
433+
CHECK_SYCL_KERNEL_ARG_TYPE(sycl_convolution_bwd_weights_conf_t);
419434

420435
} // namespace sycl
421436
} // namespace generic

0 commit comments

Comments
 (0)