Skip to content

Commit 4a2d592

Browse files
dmitry-gorokhov authored and azhai219 committed
[FORK][FEATURE] Enabled BWD (JIT/GEMM) FP32/BF16 Convolutions + Depthwise post-ops fusing
1 parent 849c55a commit 4a2d592

35 files changed

+818
-59
lines changed

src/cpu/gemm_convolution.cpp

+43
Original file line numberDiff line numberDiff line change
@@ -467,6 +467,8 @@ status_t gemm_convolution_bwd_data_t::execute_backward_data_thr_nspc(
467467
// threads share work across mini-batch and groups
468468
const dim_t work_amount = jcp.ngroups * jcp.mb;
469469

470+
const auto &p = pd()->attr()->post_ops_;
471+
470472
data_t *__restrict col = scratchpad.get<data_t>(key_conv_gemm_col)
471473
+ (ptrdiff_t)ithr * jcp.im2col_sz;
472474
const bool acc_needed = jcp.ngroups > 1;
@@ -515,6 +517,25 @@ status_t gemm_convolution_bwd_data_t::execute_backward_data_thr_nspc(
515517
}
516518
});
517519
}
520+
// Apply the fused depthwise post-ops in place to diff_src for the current
// work item. The outer guard is equivalent to the loop condition below
// (the loop body never runs when p.len() == 0).
// NOTE(review): this function indexes its scratchpad by ithr, which suggests
// it already runs per-thread; confirm that calling parallel_nd from here
// (nested parallelism) is intended.
if (p.len() > 0) {
521+
// Index into depthwise_injectors; advanced only for depthwise entries.
int depthwise_inj_idx = 0;
522+
for (int i = 0; i < p.len(); i++) {
523+
auto &post_op = p.entry_[i];
524+
if (post_op.is_depthwise()) {
525+
auto depthwise_weights = post_op.depthwise.weights_data;
526+
auto depthwise_bias = post_op.depthwise.biases_data;
527+
// One iteration per (jcp.is * jcp.id) position; each position addresses
// a run of jcp.ic values starting at diff_src + is * diff_src_os_stride.
parallel_nd(static_cast<size_t>(jcp.is) * jcp.id, [&](size_t is) {
528+
data_t *__restrict diff_src_arr
529+
= diff_src + is * diff_src_os_stride;
530+
for (int ic = 0; ic < jcp.ic; ic++) {
531+
// Weights and bias are addressed by the absolute channel g * jcp.ic + ic.
diff_src_arr[ic] = depthwise_injectors[depthwise_inj_idx]->compute_scalar(diff_src_arr[ic],
532+
depthwise_weights + g * jcp.ic + ic, depthwise_bias + g * jcp.ic + ic);
533+
}
534+
});
535+
depthwise_inj_idx++;
536+
}
537+
}
538+
}
518539
nd_iterator_step(n, jcp.mb, g, jcp.ngroups);
519540
}
520541
return status::success;
@@ -547,6 +568,8 @@ status_t gemm_convolution_bwd_data_t::execute_backward_data_ncsp(
547568
const dim_t work_amount = (size_t)jcp.ngroups * jcp.mb;
548569
const bool is_problem_3d = pd()->ndims() == 5;
549570

571+
const auto &p = pd()->attr()->post_ops_;
572+
550573
std::atomic<status_t> st(status::success);
551574
parallel(jcp.nthr, [&](const int ithr, const int nthr) {
552575
data_t *_col = col + (ptrdiff_t)ithr * jcp.im2col_sz;
@@ -594,6 +617,26 @@ status_t gemm_convolution_bwd_data_t::execute_backward_data_ncsp(
594617
}
595618
}
596619
}
620+
// Apply the fused depthwise post-ops in place to _diff_src for the current
// work item. The outer guard is equivalent to the loop condition below
// (the loop body never runs when p.len() == 0).
// NOTE(review): this code sits inside the parallel(jcp.nthr, ...) lambda, so
// the parallel_nd call below is nested inside an already-parallel region;
// confirm that this is intended for every threading runtime in use.
if (p.len() > 0) {
621+
// Index into depthwise_injectors; advanced only for depthwise entries.
int depthwise_inj_idx = 0;
622+
for (int i = 0; i < p.len(); i++) {
623+
auto &post_op = p.entry_[i];
624+
if (post_op.is_depthwise()) {
625+
auto depthwise_weights = post_op.depthwise.weights_data;
626+
auto depthwise_bias = post_op.depthwise.biases_data;
627+
// One iteration per input channel of the current group.
parallel_nd(jcp.ic, [&](const int ic) {
628+
for (int id = 0; id < jcp.id; ++id) {
629+
// Start of the jcp.is-long row for (ic, id) inside _diff_src.
data_t *d_ = _diff_src + ic * jcp.id * jcp.is + id * jcp.is;
630+
for (int iS = 0; iS < jcp.is; ++iS) {
631+
// Weights and bias are addressed by the absolute channel g * jcp.ic + ic.
d_[iS] = depthwise_injectors[depthwise_inj_idx]->compute_scalar(d_[iS],
632+
depthwise_weights + g * jcp.ic + ic, depthwise_bias + g * jcp.ic + ic);
633+
}
634+
}
635+
});
636+
depthwise_inj_idx++;
637+
}
638+
}
639+
}
597640
nd_iterator_step(g, jcp.ngroups, n, jcp.mb);
598641
}
599642
});

src/cpu/gemm_convolution.hpp

+39-2
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@
2828
#include "cpu/gemm_convolution_utils.hpp"
2929
#include "cpu/primitive_attr_postops.hpp"
3030

31+
#include "ref_depthwise_injector.hpp"
32+
3133
namespace dnnl {
3234
namespace impl {
3335
namespace cpu {
@@ -156,7 +158,7 @@ struct gemm_convolution_bwd_data_t : public primitive_t {
156158
VERBOSE_BAD_ALGORITHM);
157159
VDISPATCH_CONV(!has_zero_dim_memory(), VERBOSE_EMPTY_TENSOR, "");
158160
VDISPATCH_CONV(
159-
attr()->has_default_values(), VERBOSE_UNSUPPORTED_ATTR);
161+
is_supported_post_ops(), VERBOSE_UNSUPPORTED_ATTR);
160162

161163
auto scratchpad = scratchpad_registry().registrar();
162164

@@ -166,9 +168,42 @@ struct gemm_convolution_bwd_data_t : public primitive_t {
166168
}
167169

168170
conv_gemm_conf_t jcp_;
171+
172+
protected:
173+
virtual bool is_supported_post_ops() const {
174+
const auto &p = this->attr()->post_ops_;
175+
if (p.len() > 1)
176+
return false;
177+
178+
auto all_post_ops_supported = [&]() {
179+
bool ok = true;
180+
181+
for (int i = 0; i < p.len(); i++) {
182+
ok = ok && utils::one_of(p.entry_[i].kind, primitive_kind::depthwise);
183+
}
184+
return ok;
185+
};
186+
187+
return all_post_ops_supported();
188+
}
169189
};
170190

171-
gemm_convolution_bwd_data_t(const pd_t *apd) : primitive_t(apd) {}
191+
192+
gemm_convolution_bwd_data_t(const pd_t *apd) : primitive_t(apd) {
193+
const auto &post_ops = pd()->attr()->post_ops_;
194+
for (int i = 0; i < post_ops.len(); i++) {
195+
auto &post_op = post_ops.entry_[i];
196+
if (post_op.is_depthwise()) {
197+
depthwise_injectors.push_back(new ref_depthwise_scalar_fwd_t(post_op.depthwise.alg));
198+
}
199+
}
200+
}
201+
202+
/// Deletes every injector owned by this primitive and empties the list.
~gemm_convolution_bwd_data_t() {
    for (auto it = depthwise_injectors.begin();
            it != depthwise_injectors.end(); ++it) {
        delete *it;
    }
    depthwise_injectors.clear();
}
172207

173208
typedef typename prec_traits<data_type::f32>::type data_t;
174209

@@ -187,6 +222,8 @@ struct gemm_convolution_bwd_data_t : public primitive_t {
187222
const memory_tracking::grantor_t &scratchpad) const;
188223

189224
const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
225+
226+
nstl::vector<ref_depthwise_scalar_fwd_t*> depthwise_injectors;
190227
};
191228

192229
struct gemm_convolution_bwd_weights_t : public primitive_t {

src/cpu/ref_convolution.cpp

+14
Original file line numberDiff line numberDiff line change
@@ -387,6 +387,8 @@ status_t ref_convolution_bwd_data_t::execute_backward_data(
387387
return ds;
388388
};
389389

390+
const auto &p = pd()->attr()->post_ops_;
391+
390392
parallel_nd(G, MB, IC, ID, IH, IW,
391393
[&](dim_t g, dim_t mb, dim_t ic, dim_t id, dim_t ih, dim_t iw) {
392394
float ds = 0;
@@ -396,6 +398,18 @@ status_t ref_convolution_bwd_data_t::execute_backward_data(
396398
else
397399
ds += ker(g, mb, ic, id, ih, iw);
398400

401+
int depthwise_inj_idx = 0;
402+
for (int i = 0; i < p.len(); i++) {
403+
auto &post_op = p.entry_[i];
404+
if (post_op.is_depthwise()) {
405+
auto depthwise_weights = post_op.depthwise.weights_data;
406+
auto depthwise_bias = post_op.depthwise.biases_data;
407+
408+
ds = depthwise_injectors[depthwise_inj_idx]->compute_scalar(ds, depthwise_weights + g * IC + ic, depthwise_bias + g * IC + ic);
409+
}
410+
depthwise_inj_idx++;
411+
}
412+
399413
const auto diff_src_off = ref_conv_utils::get_data_off(
400414
diff_src_d, ndims, mb, g * IC + ic, id, ih, iw);
401415
io::store_float_value(

src/cpu/ref_convolution.hpp

+40-2
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@
2727
#include "cpu/cpu_convolution_pd.hpp"
2828
#include "cpu/primitive_attr_postops.hpp"
2929

30+
#include "ref_depthwise_injector.hpp"
31+
3032
namespace dnnl {
3133
namespace impl {
3234
namespace cpu {
@@ -118,7 +120,9 @@ struct ref_convolution_bwd_data_t : public primitive_t {
118120
&& utils::one_of(diff_dst_type, f32, bf16, f16)
119121
&& wei_type == diff_dst_type
120122
&& utils::one_of(diff_src_type, f32, diff_dst_type)
121-
&& set_default_formats() && attr()->has_default_values();
123+
&& set_default_formats()
124+
&& attr()->has_default_values(primitive_attr_t::skip_mask_t::post_ops)
125+
&& is_supported_post_ops();
122126

123127
return ok ? status::success : status::unimplemented;
124128
}
@@ -132,9 +136,41 @@ struct ref_convolution_bwd_data_t : public primitive_t {
132136
: utils::pick(ndims() - 3, oiw, oihw, oidhw);
133137
return set_default_formats_common(dat_tag, wei_tag, dat_tag);
134138
}
139+
140+
bool is_supported_post_ops() const {
141+
const auto &p = this->attr()->post_ops_;
142+
if (p.len() > 1)
143+
return false;
144+
145+
auto all_post_ops_supported = [&]() {
146+
bool ok = true;
147+
148+
for (int i = 0; i < p.len(); i++) {
149+
ok = ok && utils::one_of(p.entry_[i].kind, primitive_kind::depthwise);
150+
}
151+
return ok;
152+
};
153+
154+
return all_post_ops_supported();
155+
}
135156
};
136157

137-
ref_convolution_bwd_data_t(const pd_t *apd) : primitive_t(apd) {}
158+
ref_convolution_bwd_data_t(const pd_t *apd) : primitive_t(apd) {
159+
const auto &post_ops = pd()->attr()->post_ops_;
160+
161+
for (int i = 0; i < post_ops.len(); i++) {
162+
auto &post_op = post_ops.entry_[i];
163+
if (post_op.is_depthwise()) {
164+
depthwise_injectors.push_back(new ref_depthwise_scalar_fwd_t(post_op.depthwise.alg));
165+
}
166+
}
167+
}
168+
169+
/// Deletes every injector owned by this primitive and empties the list.
~ref_convolution_bwd_data_t() {
    for (auto it = depthwise_injectors.begin();
            it != depthwise_injectors.end(); ++it) {
        delete *it;
    }
    depthwise_injectors.clear();
}
138174

139175
status_t execute(const exec_ctx_t &ctx) const override {
140176
return execute_backward_data(ctx);
@@ -143,6 +179,8 @@ struct ref_convolution_bwd_data_t : public primitive_t {
143179
private:
144180
status_t execute_backward_data(const exec_ctx_t &ctx) const;
145181
const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
182+
183+
nstl::vector<ref_depthwise_scalar_fwd_t*> depthwise_injectors;
146184
};
147185

148186
struct ref_convolution_bwd_weights_t : public primitive_t {

src/cpu/ref_depthwise_injector.cpp

+79
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
/*******************************************************************************
2+
* Copyright 2020 Intel Corporation
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*******************************************************************************/
16+
17+
#include "ref_depthwise_injector.hpp"
18+
19+
namespace dnnl {
20+
namespace impl {
21+
namespace cpu {
22+
23+
using namespace alg_kind;
24+
using namespace math;
25+
26+
/// Scale-shift transform: returns x * scale + shift.
template <typename T>
inline T scale_shift_fwd(T x, T scale, T shift) {
    return x * scale + shift;
}
29+
30+
/// PReLU transform: non-negative inputs pass through unchanged, negative
/// inputs are multiplied by the slope.
template <typename T>
inline T prelu_fwd(T x, T slope) {
    if (x >= 0) return x;
    return x * slope;
}
33+
34+
// Overlays a float with an array of two unsigned shorts so that the two
// halves of the float's storage can be read and written separately.
// NOTE(review): assumes unsigned short is 16 bits and float is 32 bits.
// NOTE(review): reading a union member other than the one last written is
// not defined behaviour in standard C++, and which half i[1] refers to
// depends on the target byte order - confirm for all supported platforms.
union float_raw {
35+
float f;
36+
unsigned short i[2];
37+
};
38+
39+
// Builds a float whose i[1] half is taken from bf16 and whose i[0] half is
// zero, then returns it.
// NOTE(review): `t.i[1] = bf16` relies on whatever conversion bfloat16_t
// offers to unsigned short; if that conversion is numeric rather than a raw
// bit copy, the result is not the intended value - confirm against the
// bfloat16_t definition.
// NOTE(review): no caller of this helper is visible in this part of the file.
static float bf16tof32(bfloat16_t bf16) {
40+
union float_raw t = { 0 };
41+
t.i[1] = bf16;
42+
t.i[0] = 0;
43+
return t.f;
44+
}
45+
46+
// Returns the i[1] half of f32's storage as a bfloat16_t; the i[0] half is
// discarded without any rounding step here.
// NOTE(review): `return t.i[1]` relies on whatever conversion bfloat16_t
// offers from unsigned short; if that conversion is numeric rather than a
// raw bit copy, the result is not the intended value - confirm.
static bfloat16_t f32tobf16(float f32) {
47+
union float_raw t = { 0 };
48+
t.f = f32;
49+
return t.i[1];
50+
}
51+
52+
// Scale-shift for bfloat16_t operands: converts all three inputs with
// bf16tof32, evaluates s_val * w_val + b_val in float, and converts the
// result back with f32tobf16.
inline bfloat16_t bf16_scale_shift_fwd(bfloat16_t s_val, bfloat16_t w_val, bfloat16_t b_val) {
53+
return f32tobf16(bf16tof32(s_val) * bf16tof32(w_val) + bf16tof32(b_val));
54+
}
55+
56+
// PReLU for bfloat16_t operands: returns s_val unchanged when s_val >= 0,
// otherwise s_val * w_val evaluated in float and converted back with
// f32tobf16.
inline bfloat16_t bf16_prelu_fwd(bfloat16_t s_val, bfloat16_t w_val) {
57+
return s_val >= 0 ? s_val : f32tobf16(bf16tof32(s_val) * bf16tof32(w_val));
58+
}
59+
60+
/// Stores the depthwise algorithm this injector will apply.
/// In builds with assertions enabled, the algorithm must be either
/// depthwise_scale_shift or depthwise_prelu.
ref_depthwise_scalar_fwd_t::ref_depthwise_scalar_fwd_t(const alg_kind_t alg_)
    : alg(alg_) {
    assert(utils::one_of(alg, alg_kind::depthwise_scale_shift,
            alg_kind::depthwise_prelu));
}
66+
67+
float ref_depthwise_scalar_fwd_t::compute_scalar(float s, const float* weights, const float* bias) {
68+
switch (alg) {
69+
case depthwise_scale_shift: return scale_shift_fwd(s, *weights, *bias);
70+
case depthwise_prelu: return prelu_fwd(s, *weights);
71+
default: assert(!"unknown depthwise alg_kind");
72+
}
73+
74+
return 0.0f;
75+
}
76+
77+
} // namespace cpu
78+
} // namespace impl
79+
} // namespace dnnl

src/cpu/ref_depthwise_injector.hpp

+40
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
/*******************************************************************************
2+
* Copyright 2020 Intel Corporation
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*******************************************************************************/
16+
17+
#ifndef REF_DEPTHWISE_INJECTOR_HPP
18+
#define REF_DEPTHWISE_INJECTOR_HPP
19+
20+
#include "common/primitive.hpp"
21+
#include "common/primitive_attr.hpp"
22+
23+
namespace dnnl {
24+
namespace impl {
25+
namespace cpu {
26+
27+
struct ref_depthwise_scalar_fwd_t {
28+
public:
29+
explicit ref_depthwise_scalar_fwd_t(alg_kind_t alg);
30+
float compute_scalar(float s, const float* weights, const float* bias);
31+
32+
private:
33+
alg_kind_t alg;
34+
};
35+
36+
} // namespace cpu
37+
} // namespace impl
38+
} // namespace dnnl
39+
40+
#endif

0 commit comments

Comments
 (0)