Skip to content

Commit 115e9fa

Browse files
dmitry-gorokhovluweizhou2016
authored andcommitted
[FORK][FEATURE] Updated sse41 jit convolutions to support padded channels
1 parent 63e956d commit 115e9fa

9 files changed

+67
-1
lines changed

src/cpu/x64/jit_sse41_1x1_conv_kernel_f32.cpp

+13
Original file line numberDiff line numberDiff line change
@@ -567,6 +567,7 @@ status_t jit_sse41_1x1_conv_kernel_f32::init_conf(jit_1x1_conv_conf_t &jcp,
567567
jcp.mb = src_d.dims()[0];
568568

569569
jcp.oc = dst_d.dims()[1] / jcp.ngroups;
570+
jcp.oc_without_padding = jcp.oc;
570571
jcp.ic = src_d.dims()[1] / jcp.ngroups;
571572

572573
jcp.ih = (ndims == 3) ? 1 : src_d.dims()[2];
@@ -645,6 +646,9 @@ status_t jit_sse41_1x1_conv_kernel_f32::init_conf(jit_1x1_conv_conf_t &jcp,
645646

646647
const int simd_w = 4;
647648

649+
jcp.oc = rnd_up(jcp.oc, simd_w*2);
650+
jcp.ic = rnd_up(jcp.ic, simd_w*2);
651+
648652
jcp.ic_block = jcp.oc_block = simd_w * 2;
649653

650654
args_ok = true && jcp.oc % jcp.oc_block == 0 && jcp.ic % jcp.ic_block == 0
@@ -810,6 +814,15 @@ status_t jit_sse41_1x1_conv_kernel_f32::init_conf(jit_1x1_conv_conf_t &jcp,
810814
return status::success;
811815
}
812816

817+
void jit_sse41_1x1_conv_kernel_f32::init_scratchpad(
818+
memory_tracking::registrar_t &scratchpad,
819+
const jit_1x1_conv_conf_t &jcp) {
820+
using namespace dnnl::impl::memory_tracking::names;
821+
822+
if (jcp.prop_kind != backward_data && jcp.oc != jcp.oc_without_padding)
823+
scratchpad.book<float>(key_conv_padded_bias, sizeof(float) * jcp.oc);
824+
}
825+
813826
} // namespace x64
814827
} // namespace cpu
815828
} // namespace impl

src/cpu/x64/jit_sse41_1x1_conv_kernel_f32.hpp

+4
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#ifndef CPU_X64_JIT_SSE41_1X1_CONV_KERNEL_F32_HPP
1818
#define CPU_X64_JIT_SSE41_1X1_CONV_KERNEL_F32_HPP
1919

20+
#include "common/memory_tracking.hpp"
2021
#include "common/c_types_map.hpp"
2122
#include "common/memory.hpp"
2223

@@ -39,6 +40,9 @@ struct jit_sse41_1x1_conv_kernel_f32 : public jit_generator {
3940
const memory_desc_wrapper &dst_d, const primitive_attr_t &attr,
4041
int nthreads);
4142

43+
static void init_scratchpad(memory_tracking::registrar_t &scratchpad,
44+
const jit_1x1_conv_conf_t &jcp);
45+
4246
DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_sse41_1x1_conv_kernel_f32)
4347

4448
jit_1x1_conv_conf_t jcp;

src/cpu/x64/jit_sse41_1x1_convolution.cpp

+10
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ namespace x64 {
3232

3333
using namespace dnnl::impl::status;
3434
using namespace dnnl::impl::utils;
35+
using namespace dnnl::impl::memory_tracking::names;
3536

3637
void jit_sse41_1x1_convolution_fwd_t::execute_forward(
3738
const exec_ctx_t &ctx) const {
@@ -52,6 +53,15 @@ void jit_sse41_1x1_convolution_fwd_t::execute_forward(
5253
: std::vector<const void *> {};
5354

5455
auto scratchpad = ctx.get_scratchpad_grantor();
56+
57+
if (pd()->wants_padded_bias()) {
58+
auto padded_bias = scratchpad.get<data_t>(key_conv_padded_bias);
59+
utils::array_copy(padded_bias, bias, kernel_->jcp.oc_without_padding);
60+
utils::array_set(padded_bias + kernel_->jcp.oc_without_padding, 0.f,
61+
kernel_->jcp.oc - kernel_->jcp.oc_without_padding);
62+
bias = padded_bias;
63+
}
64+
5565
parallel(kernel_->jcp.nthr, [&](const int ithr, const int nthr) {
5666
execute_forward_thr(ithr, nthr, src, weights, bias, weights_dw, bias_dw,
5767
dst, scratchpad, post_ops_binary_rhs_arg_vec.data(),

src/cpu/x64/jit_sse41_1x1_convolution.hpp

+3
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,9 @@ struct jit_sse41_1x1_convolution_fwd_t : public primitive_t {
7373
dnnl_get_max_threads()));
7474
if (jcp_.with_dw_conv) CHECK(depthwise_po_init(engine));
7575

76+
auto scratchpad = scratchpad_registry().registrar();
77+
jit_sse41_1x1_conv_kernel_f32::init_scratchpad(scratchpad, jcp_);
78+
7679
return status::success;
7780
}
7881

src/cpu/x64/jit_sse41_conv_kernel_f32.cpp

+19
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ namespace x64 {
3535
using namespace dnnl::impl::format_tag;
3636
using namespace dnnl::impl::prop_kind;
3737
using namespace dnnl::impl::utils;
38+
using namespace dnnl::impl::memory_tracking::names;
3839

3940
using namespace Xbyak;
4041

@@ -398,6 +399,7 @@ status_t jit_sse41_conv_fwd_kernel_f32::init_conf(jit_conv_conf_t &jcp,
398399
jcp.mb = src_d.dims()[0];
399400

400401
jcp.oc = dst_d.dims()[1] / jcp.ngroups;
402+
jcp.oc_without_padding = jcp.oc;
401403
jcp.ic = src_d.dims()[1] / jcp.ngroups;
402404

403405
jcp.ih = (ndims == 3) ? 1 : src_d.dims()[2];
@@ -491,7 +493,15 @@ status_t jit_sse41_conv_fwd_kernel_f32::init_conf(jit_conv_conf_t &jcp,
491493
VDISPATCH_CONV_IC(channel_pad_ok, VERBOSE_UNSUPPORTED_PAD_FEATURE,
492494
"i/o and padded channel size mismatch");
493495

496+
bool ok_to_pad_channels = true && jcp.ngroups == 1;
497+
494498
const int simd_w = 8; // 2 SSE vectors processing at once
499+
if (ok_to_pad_channels) {
500+
jcp.oc = rnd_up(jcp.oc, simd_w);
501+
if (mimo) {
502+
jcp.ic = rnd_up(jcp.ic, simd_w);
503+
}
504+
}
495505

496506
jcp.ur_h = 1; /* no code-unrolling by h so far */
497507
jcp.ur_w = 3;
@@ -549,6 +559,15 @@ status_t jit_sse41_conv_fwd_kernel_f32::init_conf(jit_conv_conf_t &jcp,
549559
return status::success;
550560
}
551561

562+
void jit_sse41_conv_fwd_kernel_f32::init_scratchpad(
563+
memory_tracking::registrar_t &scratchpad,
564+
const jit_conv_conf_t &jcp) {
565+
using namespace dnnl::impl::memory_tracking::names;
566+
567+
if (jcp.prop_kind != backward_data && jcp.oc != jcp.oc_without_padding)
568+
scratchpad.book<float>(key_conv_padded_bias, sizeof(float) * jcp.oc);
569+
}
570+
552571
} // namespace x64
553572
} // namespace cpu
554573
} // namespace impl

src/cpu/x64/jit_sse41_conv_kernel_f32.hpp

+4
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#ifndef CPU_X64_JIT_SSE41_CONV_KERNEL_F32_HPP
1818
#define CPU_X64_JIT_SSE41_CONV_KERNEL_F32_HPP
1919

20+
#include "common/memory_tracking.hpp"
2021
#include "common/c_types_map.hpp"
2122
#include "common/memory.hpp"
2223

@@ -39,6 +40,9 @@ struct jit_sse41_conv_fwd_kernel_f32 : public jit_generator {
3940
const memory_desc_wrapper &dst_d, const primitive_attr_t &attr,
4041
int nthreads);
4142

43+
static void init_scratchpad(memory_tracking::registrar_t &scratchpad,
44+
const jit_conv_conf_t &jcp);
45+
4246
DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_sse41_conv_fwd_kernel_f32)
4347
jit_conv_conf_t jcp;
4448
const primitive_attr_t &attr_;

src/cpu/x64/jit_sse41_convolution.cpp

+10
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ namespace x64 {
2727

2828
using namespace dnnl::impl::status;
2929
using namespace dnnl::impl::utils;
30+
using namespace dnnl::impl::memory_tracking::names;
3031

3132
#define src_blk_off(f, n, c, h, w) \
3233
(pd()->ndims() == 3) ? (f).blk_off(n, c, w) : (f).blk_off(n, c, h, w)
@@ -60,6 +61,15 @@ void jit_sse41_convolution_fwd_t::execute_forward(const exec_ctx_t &ctx) const {
6061
const bool is_dst_layout_nxc
6162
= one_of(jcp.dst_tag, format_tag::nwc, format_tag::nhwc);
6263

64+
auto scratchpad = ctx.get_scratchpad_grantor();
65+
if (pd()->wants_padded_bias()) {
66+
auto padded_bias = scratchpad.get<data_t>(key_conv_padded_bias);
67+
utils::array_copy(padded_bias, bias, kernel_->jcp.oc_without_padding);
68+
utils::array_set(padded_bias + kernel_->jcp.oc_without_padding, 0.f,
69+
kernel_->jcp.oc - kernel_->jcp.oc_without_padding);
70+
bias = padded_bias;
71+
}
72+
6373
parallel(jcp.nthr, [&](const int ithr, const int nthr) {
6474
assert(nthr == jcp.nthr);
6575

src/cpu/x64/jit_sse41_convolution.hpp

+3
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,9 @@ struct jit_sse41_convolution_fwd_t : public primitive_t {
6262
*src_md(), *weights_md(), *dst_md(), *attr(),
6363
dnnl_get_max_threads()));
6464

65+
auto scratchpad = scratchpad_registry().registrar();
66+
jit_sse41_conv_fwd_kernel_f32::init_scratchpad(scratchpad, jcp_);
67+
6568
return status::success;
6669
}
6770

src/cpu/x64/jit_uni_dw_conv_kernel_utils.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -237,7 +237,7 @@ status_t jit_uni_dw_conv_fwd_kernel<isa, kernel_dt>::init_conf(
237237

238238
const bool ok_to_pad_channels = true && !is_data_layout_nxc
239239
&& jcp.oc == jcp.ngroups && jcp.ic == jcp.ngroups
240-
&& one_of(isa, avx512_core, avx2);
240+
&& one_of(isa, avx512_core, avx2, sse41);
241241
if (ok_to_pad_channels) {
242242
jcp.oc = rnd_up(jcp.oc, simd_w);
243243
jcp.ic = rnd_up(jcp.oc, simd_w);

0 commit comments

Comments
 (0)