
Commit 57568a4

gpu: generic: deconvolution: move implementation from ocl folder (#2144)
1 parent 2dfc0d6 commit 57568a4

3 files changed (+380, -293 lines)
@@ -0,0 +1,376 @@
/*******************************************************************************
* Copyright 2024 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#ifndef GPU_GENERIC_CONVOLUTION_DECONVOLUTION_HPP
#define GPU_GENERIC_CONVOLUTION_DECONVOLUTION_HPP

#include "common/c_types_map.hpp"
#include "common/primitive.hpp"
#include "common/primitive_desc_iterator.hpp"
#include "common/type_helpers.hpp"
#include "common/utils.hpp"
#include "gpu/gpu_deconvolution_pd.hpp"
#include "gpu/gpu_primitive.hpp"

namespace dnnl {
namespace impl {
namespace gpu {
namespace generic {

static status_t weights_axes_permutation(
        memory_desc_t *o_md, const memory_desc_t *i_md, bool with_groups) {
    int perm[DNNL_MAX_NDIMS] {}; // deconv to conv weight permutation
    for (int d = 0; d < DNNL_MAX_NDIMS; ++d)
        perm[d] = d;
    nstl::swap(perm[0 + with_groups], perm[1 + with_groups]);

    return memory_desc_permute_axes(*o_md, *i_md, perm);
}

static status_t conv_descr_create(
        const deconvolution_desc_t *dd, convolution_desc_t *cd) {
    using namespace prop_kind;
    alg_kind_t alg_kind = alg_kind::convolution_direct;

    const memory_desc_t *src_md, *dst_md, *d_weights_d;
    prop_kind_t prop_kind;

    switch (dd->prop_kind) {
        case forward:
        case forward_inference:
            prop_kind = backward_data;
            src_md = &dd->dst_desc;
            dst_md = &dd->src_desc;
            d_weights_d = &dd->weights_desc;
            break;
        case backward_data:
            prop_kind = forward_training;
            src_md = &dd->diff_dst_desc;
            dst_md = &dd->diff_src_desc;
            d_weights_d = &dd->weights_desc;
            break;
        case backward_weights:
            prop_kind = dd->prop_kind;
            src_md = &dd->diff_dst_desc;
            dst_md = &dd->src_desc;
            d_weights_d = &dd->diff_weights_desc;
            break;
        default: assert(!"unknown prop kind"); return status::invalid_arguments;
    }

    // Create weights desc for convolution
    memory_desc_t c_weights_d;
    const bool with_groups = d_weights_d->ndims == src_md->ndims + 1;
    CHECK(weights_axes_permutation(&c_weights_d, d_weights_d, with_groups));

    return conv_desc_init(cd, prop_kind, alg_kind, src_md, &c_weights_d,
            prop_kind != backward_weights ? &dd->bias_desc : nullptr, dst_md,
            dd->strides, dd->dilates, dd->padding[0], dd->padding[1]);
}

struct convolution_deconvolution_fwd_t : public gpu::primitive_t {
    using gpu::primitive_t::primitive_t;
    struct pd_t : public gpu_deconvolution_fwd_pd_t {
        pd_t(const deconvolution_desc_t *adesc, const primitive_attr_t *attr,
                const deconvolution_fwd_pd_t *hint_fwd_pd)
            : gpu_deconvolution_fwd_pd_t(adesc, attr, hint_fwd_pd) {}

        pd_t(const pd_t &other) = default;

        ~pd_t() = default;

        DECLARE_COMMON_PD_T(name_.c_str(), convolution_deconvolution_fwd_t);
        status_t init_convolution(impl::engine_t *engine) {
            convolution_desc_t cd;
            CHECK(conv_descr_create(desc(), &cd));
            primitive_attr_t conv_attr(*attr());
            if (!conv_attr.is_initialized()) return status::out_of_memory;
            primitive_desc_iterator_t it(
                    engine, (op_desc_t *)&cd, &conv_attr, nullptr);
            if (!it.is_initialized()) return status::out_of_memory;
            conv_pd_ = *(++it);

            return (conv_pd_) ? status::success : status::unimplemented;
        }

        status_t init(impl::engine_t *engine) {
            using namespace format_tag;
            using sm = primitive_attr_t::skip_mask_t;

            const auto attr_skip_mask = sm::post_ops | sm::zero_points_runtime
                    | sm::scales_runtime;

            VDISPATCH_DECONVOLUTION(is_fwd(), VERBOSE_BAD_PROPKIND);
            VDISPATCH_DECONVOLUTION(
                    desc()->alg_kind == alg_kind::deconvolution_direct,
                    VERBOSE_BAD_ALGORITHM);
            VDISPATCH_DECONVOLUTION(attr()->has_default_values(attr_skip_mask),
                    VERBOSE_UNSUPPORTED_ATTR);
            VDISPATCH_DECONVOLUTION(
                    (utils::everyone_is(data_type::f32,
                             desc()->src_desc.data_type,
                             desc()->weights_desc.data_type,
                             desc()->dst_desc.data_type)
                            || (utils::everyone_is(data_type::f64,
                                    desc()->src_desc.data_type,
                                    desc()->weights_desc.data_type,
                                    desc()->dst_desc.data_type))
                            || ((utils::everyone_is(data_type::f16,
                                         desc()->src_desc.data_type,
                                         desc()->weights_desc.data_type)
                                        || utils::everyone_is(data_type::f32,
                                                desc()->src_desc.data_type,
                                                desc()->weights_desc.data_type)
                                        || utils::everyone_is(data_type::bf16,
                                                desc()->src_desc.data_type,
                                                desc()->weights_desc.data_type))
                                    && utils::one_of(desc()->dst_desc.data_type,
                                            data_type::f16, data_type::u8,
                                            data_type::s8))
                            || (utils::everyone_is(data_type::bf16,
                                        desc()->src_desc.data_type,
                                        desc()->weights_desc.data_type)
                                    && utils::one_of(desc()->dst_desc.data_type,
                                            data_type::f32, data_type::bf16))
                            || (utils::everyone_is(data_type::f16,
                                        desc()->src_desc.data_type,
                                        desc()->weights_desc.data_type)
                                    && utils::one_of(desc()->dst_desc.data_type,
                                            data_type::f32, data_type::f16))
                            || (desc()->weights_desc.data_type == data_type::s8
                                    && utils::one_of(desc()->src_desc.data_type,
                                            data_type::u8, data_type::s8)
                                    && desc()->dst_desc.data_type
                                            != data_type::f64)),
                    VERBOSE_UNSUPPORTED_DT);

            VDISPATCH_DECONVOLUTION_SC(
                    init_convolution(engine), "init_convolution()");
            if (weights_md_.format_kind == format_kind::any) {
                VDISPATCH_DECONVOLUTION_SC(
                        weights_axes_permutation(&weights_md_,
                                conv_pd_->weights_md(), with_groups()),
                        "weights_axes_permutation()");
            }
            if (src_md_.format_kind == format_kind::any)
                src_md_ = *conv_pd_->diff_dst_md();
            if (dst_md_.format_kind == format_kind::any)
                dst_md_ = *conv_pd_->diff_src_md();
            if (bias_md_.format_kind == format_kind::any) {
                VDISPATCH_DECONVOLUTION_SC(memory_desc_init_by_tag(bias_md_, x),
                        VERBOSE_UNSUPPORTED_TAG);
            }
            init_name();
            init_scratchpad();
            VDISPATCH_DECONVOLUTION_SC(attr_.set_default_formats(dst_md(0)),
                    VERBOSE_UNSUPPORTED_ATTR);

            return status::success;
        }

        std::shared_ptr<primitive_desc_t> conv_pd_;

    private:
        std::string name_ = "conv:any";

        void init_name() {
            name_.append("+");
            name_.append(conv_pd_->name());
        }

        void init_scratchpad() {
            auto scratchpad = scratchpad_registry().registrar();
            scratchpad.book(memory_tracking::names::key_nested,
                    conv_pd_->scratchpad_registry());
        }
    };

    status_t init(impl::engine_t *engine) override {
        return create_nested_primitive(conv_p_, pd()->conv_pd_, engine);
    }

    status_t execute(const exec_ctx_t &ctx) const override {
        using namespace memory_tracking::names;
        const auto &args = ctx.args();
        exec_args_t conv_args;
        conv_args[DNNL_ARG_DIFF_DST] = args.at(DNNL_ARG_SRC);
        conv_args[DNNL_ARG_WEIGHTS] = args.at(DNNL_ARG_WEIGHTS);
        conv_args[DNNL_ARG_DIFF_SRC] = args.at(DNNL_ARG_DST);
        if (pd()->with_bias())
            conv_args[DNNL_ARG_BIAS] = args.at(DNNL_ARG_BIAS);

        for (int idx = 0; idx < pd()->attr()->post_ops_.len(); ++idx) {
            if (pd()->attr()->post_ops_.entry_[idx].is_binary()) {
                conv_args[DNNL_ARG_ATTR_MULTIPLE_POST_OP(idx) | DNNL_ARG_SRC_1]
                        = args.at(DNNL_ARG_ATTR_MULTIPLE_POST_OP(idx)
                                | DNNL_ARG_SRC_1);
            } else if (pd()->attr()->post_ops_.entry_[idx].is_prelu()) {
                conv_args[DNNL_ARG_ATTR_MULTIPLE_POST_OP(idx)
                        | DNNL_ARG_WEIGHTS]
                        = args.at(DNNL_ARG_ATTR_MULTIPLE_POST_OP(idx)
                                | DNNL_ARG_WEIGHTS);
            }
        }
        const auto z_src = DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC;
        const auto z_dst = DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_DST;
        if (args.find(z_src) != args.end()) conv_args[z_src] = args.at(z_src);
        if (args.find(z_dst) != args.end()) conv_args[z_dst] = args.at(z_dst);

        for (int arg : {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_DST}) {
            int key = DNNL_ARG_ATTR_SCALES | arg;
            if (args.find(key) != args.end()) conv_args[key] = args.at(key);
        }

        exec_ctx_t conv_ctx(ctx, std::move(conv_args));

        nested_scratchpad_t ns(ctx, key_nested, conv_p_);
        conv_ctx.set_scratchpad_grantor(ns.grantor());
        // Executing the convolution kernel
        return conv_p_->execute(conv_ctx);
    }

private:
    const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
    std::shared_ptr<impl::primitive_t> conv_p_;
};

struct convolution_deconvolution_bwd_data_t : public gpu::primitive_t {
    using gpu::primitive_t::primitive_t;
    struct pd_t : public gpu_deconvolution_bwd_data_pd_t {
        pd_t(const deconvolution_desc_t *adesc, const primitive_attr_t *attr,
                const deconvolution_fwd_pd_t *hint_fwd_pd)
            : gpu_deconvolution_bwd_data_pd_t(adesc, attr, hint_fwd_pd)
            , conv_pd_(nullptr) {}

        pd_t(const pd_t &other) = default;

        ~pd_t() = default;

        DECLARE_COMMON_PD_T(
                name_.c_str(), convolution_deconvolution_bwd_data_t);

        status_t init_convolution(impl::engine_t *engine) {
            convolution_desc_t cd;
            CHECK(conv_descr_create(desc(), &cd));
            primitive_attr_t conv_attr(*attr());
            if (!conv_attr.is_initialized()) return status::out_of_memory;
            primitive_desc_iterator_t it(
                    engine, (op_desc_t *)&cd, &conv_attr, nullptr);
            if (!it.is_initialized()) return status::out_of_memory;
            conv_pd_ = *(++it);
            return (conv_pd_) ? status::success : status::unimplemented;
        }

        status_t init(impl::engine_t *engine) {
            VDISPATCH_DECONVOLUTION(
                    desc()->prop_kind == prop_kind::backward_data,
                    VERBOSE_BAD_PROPKIND);

            VDISPATCH_DECONVOLUTION(
                    (utils::everyone_is(data_type::f32,
                             desc()->diff_src_desc.data_type,
                             desc()->weights_desc.data_type,
                             desc()->diff_dst_desc.data_type)
                            || (utils::everyone_is(data_type::f64,
                                    desc()->diff_src_desc.data_type,
                                    desc()->weights_desc.data_type,
                                    desc()->diff_dst_desc.data_type))
                            || utils::everyone_is(data_type::f16,
                                    desc()->weights_desc.data_type,
                                    desc()->diff_dst_desc.data_type)
                            || utils::everyone_is(data_type::bf16,
                                    desc()->weights_desc.data_type,
                                    desc()->diff_dst_desc.data_type)),
                    VERBOSE_UNSUPPORTED_DT);

            VDISPATCH_DECONVOLUTION(
                    utils::one_of(desc()->diff_src_desc.data_type,
                            data_type::bf16, data_type::f16, data_type::f32,
                            data_type::f64),
                    VERBOSE_UNSUPPORTED_DT);
            VDISPATCH_DECONVOLUTION(
                    desc()->alg_kind == alg_kind::deconvolution_direct,
                    VERBOSE_BAD_ALGORITHM);
            VDISPATCH_DECONVOLUTION(
                    attr()->has_default_values(), VERBOSE_UNSUPPORTED_ATTR);

            VDISPATCH_DECONVOLUTION_SC(
                    init_convolution(engine), "init_convolution()");
            if (weights_md_.format_kind == format_kind::any)
                VDISPATCH_DECONVOLUTION_SC(
                        weights_axes_permutation(&weights_md_,
                                conv_pd_->weights_md(), with_groups()),
                        "weights_axes_permutation()");
            if (diff_src_md_.format_kind == format_kind::any)
                diff_src_md_ = *conv_pd_->dst_md();
            if (diff_dst_md_.format_kind == format_kind::any)
                diff_dst_md_ = *conv_pd_->src_md();

            init_name();
            init_scratchpad();

            return status::success;
        }

        std::shared_ptr<primitive_desc_t> conv_pd_;

    private:
        std::string name_ = "conv:any";

        void init_name() {
            name_.append("+");
            name_.append(conv_pd_->name());
        }

        void init_scratchpad() {
            auto scratchpad = scratchpad_registry().registrar();
            scratchpad.book(memory_tracking::names::key_nested,
                    conv_pd_->scratchpad_registry());
        }
    };

    status_t init(impl::engine_t *engine) override {
        return create_nested_primitive(conv_p_, pd()->conv_pd_, engine);
    }

    status_t execute(const exec_ctx_t &ctx) const override {
        using namespace memory_tracking::names;
        const auto &args = ctx.args();
        exec_args_t conv_args;
        conv_args[DNNL_ARG_SRC] = args.at(DNNL_ARG_DIFF_DST);
        conv_args[DNNL_ARG_WEIGHTS] = args.at(DNNL_ARG_WEIGHTS);
        conv_args[DNNL_ARG_DST] = args.at(DNNL_ARG_DIFF_SRC);
        if (!types::is_zero_md(pd()->scratchpad_md()))
            conv_args[DNNL_ARG_SCRATCHPAD] = args.at(DNNL_ARG_SCRATCHPAD);
        exec_ctx_t conv_ctx(ctx, std::move(conv_args));

        nested_scratchpad_t ns(ctx, key_nested, conv_p_);
        conv_ctx.set_scratchpad_grantor(ns.grantor());
        // Executing the convolution kernel
        return conv_p_->execute(conv_ctx);
    }

private:
    const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
    std::shared_ptr<impl::primitive_t> conv_p_;
};

} // namespace generic
} // namespace gpu
} // namespace impl
} // namespace dnnl

#endif
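
For context only (not part of this commit): both implementations above express a deconvolution as a convolution with permuted weight axes, which is why the forward execute() binds the user's DNNL_ARG_SRC to the nested convolution's DNNL_ARG_DIFF_DST. Below is a minimal sketch of how such a primitive is typically reached from the public oneDNN C++ API; it assumes oneDNN v3.x (dnnl.hpp), a GPU engine at index 0, and illustrative f32 shapes, strides, and padding, so treat it as an example rather than a test taken from the repository.

#include "dnnl.hpp"

int main() {
    using namespace dnnl;
    engine eng(engine::kind::gpu, 0); // assumes a GPU device is present
    stream strm(eng);

    // Illustrative 2x upsampling: N=1, IC=16, OC=32, 4x4 kernel, stride 2, pad 1.
    memory::desc src_md({1, 16, 32, 32}, memory::data_type::f32,
            memory::format_tag::any);
    memory::desc wei_md({32, 16, 4, 4}, memory::data_type::f32,
            memory::format_tag::any);
    memory::desc dst_md({1, 32, 64, 64}, memory::data_type::f32,
            memory::format_tag::any);

    // The library picks an implementation; when the convolution-backed path
    // above is chosen, its reported name starts with "conv:any" (see name_).
    auto pd = deconvolution_forward::primitive_desc(eng,
            prop_kind::forward_inference, algorithm::deconvolution_direct,
            src_md, wei_md, dst_md, /*strides=*/{2, 2}, /*padding_l=*/{1, 1},
            /*padding_r=*/{1, 1});

    memory src(pd.src_desc(), eng), wei(pd.weights_desc(), eng),
            dst(pd.dst_desc(), eng);
    deconvolution_forward(pd).execute(strm,
            {{DNNL_ARG_SRC, src}, {DNNL_ARG_WEIGHTS, wei},
                    {DNNL_ARG_DST, dst}});
    strm.wait();
    return 0;
}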

0 commit comments
