oneDNN 2.5 migration #121

Open
wants to merge 63 commits into base: v2.5_for_ie_master

Changes from 1 commit

Commits (63)
8109bd8
jit: avx512: conv: Fix missed ur_w iteration
AlexPeskov Mar 26, 2018
e248b50
Enable jit sse41 NxN convolution for grayscale input
Jun 5, 2018
1c6c17d
Support of strided blobs for [de]convolution and simple reorder
AlexPeskov Jan 23, 2018
94dbed3
Updated sse41 jit convolutions to support padded channels
Oct 26, 2018
a707196
Introduced Depthwise and Quantization post ops
Sep 24, 2020
d31c5c4
Pooling pads like Caffe
AlexPeskov Mar 5, 2018
21d958b
TBB_AUTO was enabled
alexey-varyzgin May 14, 2019
23dacbf
Add API function dnnl_memory_set_data_handle_no_pads_proc
AlexPeskov Aug 27, 2019
d22cb71
nchw_pooling dense fix
alexey-varyzgin Nov 14, 2019
b17c0cc
Enabled BWD (JIT/GEMM) FP32/BF16 Convolutions + Depthwise post ops fus…
Oct 21, 2020
d83837b
Fixes for MKLDNN to enable LTO
ilya-lavrenov May 18, 2020
8e0bbf7
[MSVC] Enabling SIMD functionality for VS2019
Aug 12, 2020
6a7071e
Avoid usage of undefined macro
AlexPeskov Oct 26, 2020
d13149c
Add several uni instruction wrappers into jit_generator
AlexPeskov Oct 26, 2020
11c81cd
Fix ODR violation
AlexPeskov Nov 16, 2020
b625aec
fix name matching with system struct 'user' in llvm-android toolchain
AlexPeskov Nov 16, 2020
e333684
Added JIT FP32/BF16 Softmax for arbitrary inner_size
Dec 4, 2020
31cd484
Added support of hsigmoid, round_half_to_even, round_half_away_from_z…
a-sidorova Aug 27, 2020
c74c508
Limit applicability of is_1stconv logic for JIT FP32/BF16 AVX512 Conv…
Dec 9, 2020
3b61547
[WA] Removed kernel_outside_src condition on JIT FP32/BF16 Convolutions
Dec 9, 2020
83137dc
Added custom version of JIT DW FP32/BF16 Convolution with 5D input su…
Dec 14, 2020
11141a1
Asymmetric quantization for activations
Nov 20, 2020
7d34a00
Added 3D DW case support for JIT INT8 Convolutions
Dec 14, 2020
fad9dbe
[WA] Disabled weights md transpose in FC to prevent perf degradations
Dec 16, 2020
8d279a1
Dynamic batch support via context
Jan 2, 2021
50043b9
Added JIT AVX512/AVX2 FP32 Planar Convolution implementation
Jan 2, 2021
ce960ad
Binary networks support
Jan 21, 2021
64232aa
Accommodating oneTBB (with hybrid cores support) that
myshevts Nov 24, 2020
daeb468
NCHW pooling performance fixed in accordance with v0.21
maxnick Feb 8, 2021
4592ba4
[WA] Fixed fallback on ref conv in case exceeding scratchpad limit
Feb 26, 2021
0c99230
Returned old behavior for fp32 avx2 1x1 conv with dw conv fusing
antonvor Feb 16, 2021
dcd0abf
Updated SoftPlus
a-sidorova Apr 12, 2021
9ccc627
Fixed warning for undefined ITT_ARCH_IA64 (#52)
ilya-lavrenov May 12, 2021
ccd5353
Disable reorder JIT if both inputs and outputs are batch-strided.
IvanNovoselov Jun 8, 2021
94c0a20
Include TBB headers as system
AlexPeskov Oct 26, 2020
f1fd7e4
Fixed redefinition of tls model
Jun 24, 2021
a8f7910
nspc layout support for convolutions
maxnick Mar 31, 2021
597c659
set scale = 1.f in case of signed input on platforms without vnni
antonvor May 26, 2021
90b00ac
Enable direct copy primitives for u8 reorder
IvanNovoselov Jul 2, 2021
de66d63
Memory descriptor dynamism related changes
maxnick Jul 23, 2021
1501344
Added prelu as binary post op
antonvor Aug 2, 2021
a7c1712
Depthwise and Quantization post ops for Gemm Convolutions
antonvor Aug 23, 2021
12e2bca
Perf fixes for Ref and NCHW Poolings
antonvor Sep 1, 2021
ab4c85f
perf fixes for quantization post ops
antonvor Sep 16, 2021
80a89c2
todo: fix assert(idx < max_idx)
antonvor Sep 16, 2021
f5f763d
simple reorder: temporarily disabled zero padding
antonvor Sep 19, 2021
98064bb
Fix possible data race when accessing global reorder list
Sep 30, 2021
49d6a78
Brgemm implementation has perf degradation in RNN node
alexey-varyzgin Oct 4, 2021
994f9c5
reverted old behavior with pdim_consistent check due to perf problems
antonvor Nov 10, 2021
a5aa5bb
Renamed matmul kernel type: brg -> brgemm
Aug 24, 2021
a4ebe69
Fixed bias addition order in brgemm kernel
Nov 3, 2021
9f776d1
[1D] Enlarge support
alexey-varyzgin Oct 22, 2021
8ac9ff7
Update uni_ jit methods to avoid mixing vex and nonVEX instructions
Nov 17, 2021
26a09d4
Quantization post op structure modified to reduce its complexity
maxnick Nov 24, 2021
298546e
Hash utility functions were extracted to a separate module for reuse
maxnick Nov 29, 2021
c4d9021
Desc similar_to routine consider start stride
maxnick Jan 14, 2022
81698b9
Desc similar_to routine use stride cmp mask
maxnick Jan 26, 2022
f8d99a9
added some legacy parallel methods to fix perf issues
antonvor Jan 17, 2022
5a77d6e
Migrate legacy post ops and zero points on runtime data pointers
Jan 26, 2022
15c2025
Revert "all: remove mkldnn compatibility layer"
Feb 21, 2022
22e2744
Revert "reverted old behavior with pdim_consistent check due to perf …
Mar 5, 2022
5ac2a40
Fixed ODR violation
Mar 14, 2022
d55288f
equality_uni_xxx_for_sse_and_avx
chenhu-wang Feb 24, 2022
Depthwise and Quantization post ops for Gemm Convolutions
antonvor authored and dmitrygo committed Feb 20, 2022
commit a7c171216182a712ae30b54acefe04fb845dba98
126 changes: 13 additions & 113 deletions src/cpu/gemm_convolution.cpp
@@ -153,65 +153,18 @@ status_t gemm_convolution_fwd_t::execute_forward_thr_nspc(const exec_ctx_t &ctx,
&LDC);
if (st != status::success) return st;

if (jcp.with_bias || jcp.with_eltwise || jcp.with_binary) {
parallel(0, [&](int ithr, int nthr) {
dim_t start, end;
balance211(N * jcp.oc, nthr, ithr, start, end);

const size_t first_oc = start % jcp.oc;
const size_t last_oc = (end - 1) % jcp.oc;
const size_t first_os = start / jcp.oc;
const size_t last_os = (end - 1) / jcp.oc;

for (size_t os = first_os; os <= last_os; ++os) {
const size_t start_oc = (os == first_os) ? first_oc : 0;
const size_t end_oc
= (os == last_os) ? last_oc : jcp.oc - 1;

const data_t *__restrict bia_arr
= bia_base ? bia_base + g * jcp.oc : nullptr;
data_t *__restrict dst_arr = dst + os * dst_os_stride;

if (jcp.with_bias) {
PRAGMA_OMP_SIMD()
for (size_t oc = start_oc; oc <= end_oc; oc++) {
dst_arr[oc] += bia_arr[oc];
}
}
if (pp_kernel_) {
const size_t first_oc = g * jcp.oc;
const size_t last_oc = jcp.oc;
const size_t first_os = 0;
const size_t last_os = N;

if (jcp.with_eltwise || jcp.with_binary) {
bool fast_relu_done = false;
if (jcp.with_eltwise && jcp.post_ops.len() == 1) {
// fast branch for ReLU case
const auto &eltwise
= jcp.post_ops.entry_.back().eltwise;

if (eltwise.alg == alg_kind::eltwise_relu) {
const auto alpha = eltwise.alpha;
const auto scale = eltwise.scale;
PRAGMA_OMP_SIMD()
for (size_t oc = start_oc; oc <= end_oc;
oc++) {
if (dst_arr[oc] < 0)
dst_arr[oc] *= alpha;
dst_arr[oc] *= scale;
}
fast_relu_done = true;
}
}
if (!fast_relu_done) {
ref_post_ops_t::args_t args;
args.ctx = &ctx;
args.dst_md = pd()->dst_md();

for (size_t oc = start_oc; oc <= end_oc; oc++) {
args.l_offset = (g * jcp.oc + oc) * jcp.os;
post_ops_->execute(dst_arr[oc], args);
}
}
}
}
});
const data_t* bias = bia_base ? bia_base + g * jcp.oc: nullptr;

for (size_t os = first_os; os < last_os; ++os) {
data_t* dst_local = dst + os * dst_os_stride;
(*pp_kernel_)(dst_local, bias, 1, first_oc, last_oc, 1);
}
}
}
nd_iterator_step(n, MB, g, jcp.ngroups, ohb, nb_oh, owb, nb_ow);
@@ -306,61 +259,8 @@ status_t gemm_convolution_fwd_t::execute_forward_ncsp(
&LDA, _weights, &LDB, &beta, _dst, &M);
if (st != status::success) return st;

if (curr.ic == jcp.ic - step.ic) {
// TODO: for "outer threading" we have parallel section within
// outermost "parallel". It is not good. Consider to use
// "parallel" here with number of threads passed as parameter
const int oc_start = curr.g * jcp.oc + curr.oc;
if (jcp.with_eltwise || jcp.with_binary) {
bool fast_relu_done = false;
if (jcp.with_eltwise && jcp.post_ops.len() == 1) {
// fast branch for ReLU case
const auto &eltwise
= jcp.post_ops.entry_.back().eltwise;
if (eltwise.alg == alg_kind::eltwise_relu) {
parallel_nd(step.oc, [&](dim_t oc) {
data_t b = jcp.with_bias ? bias[oc_start + oc]
: 0;
data_t *d_ = _dst + oc * M;
PRAGMA_OMP_SIMD()
for (int oS = 0; oS < m; ++oS) {
d_[oS] += b;
if (d_[oS] < 0) d_[oS] *= eltwise.alpha;
d_[oS] *= eltwise.scale;
}
});
fast_relu_done = true;
}
}
if (!fast_relu_done) {
parallel_nd(step.oc, [&](dim_t oc) {
data_t b = jcp.with_bias ? bias[oc_start + oc] : 0;
data_t *d_ = _dst + oc * M;

ref_post_ops_t::args_t args;
args.ctx = &ctx;
args.dst_md = pd()->dst_md();
args.l_offset = d_ - dst;

PRAGMA_OMP_SIMD()
for (int oS = 0; oS < m; ++oS) {
d_[oS] += b;
post_ops_->execute(d_[oS], args);
args.l_offset++;
}
});
}

} else if (jcp.with_bias) {
parallel_nd(step.oc, [&](dim_t oc) {
data_t b = bias[oc_start + oc];
data_t *d_ = _dst + oc * M;
PRAGMA_OMP_SIMD()
for (int oS = 0; oS < m; ++oS) {
d_[oS] += b;
}
});
}
if (pp_kernel_ && curr.ic == jcp.ic - step.ic) {
(*pp_kernel_)(_dst, bias, m, curr.g * jcp.oc + curr.oc, step.oc, M);
}

return status::success;
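
Both call sites above now hand post processing to the new pp_kernel_ functor with the argument order (dst, bias, len, oc_start, oc_work, oc_stride): len spatial values per output channel, oc_work channels starting at oc_start, and consecutive channel rows oc_stride floats apart. A minimal self-contained sketch of the bias-only case under that calling convention (plain C++, hypothetical names, no oneDNN types); the ref_pp_kernel_t added later in this commit additionally chains eltwise, depthwise and quantization post ops:

#include <vector>

// Simplified stand-in for the call shape used above:
//   (*pp_kernel_)(dst, bias, len, oc_start, oc_work, oc_stride)
void apply_bias(float *dst, const float *bias, int len,
                int oc_start, int oc_work, int oc_stride) {
    for (int oc = 0; oc < oc_work; ++oc) {
        float *d = dst + oc * oc_stride;              // channel row of len values
        const float b = bias ? bias[oc_start + oc] : 0.f;
        for (int os = 0; os < len; ++os)
            d[os] += b;
    }
}

int main() {
    std::vector<float> dst(4 * 8, 0.f);               // 4 channels x 8 spatial points
    std::vector<float> bias {1.f, 2.f, 3.f, 4.f};
    apply_bias(dst.data(), bias.data(), /*len=*/8, /*oc_start=*/0,
               /*oc_work=*/4, /*oc_stride=*/8);
}
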
44 changes: 27 additions & 17 deletions src/cpu/gemm_convolution.hpp
@@ -52,7 +52,6 @@ struct gemm_convolution_fwd_t : public primitive_t {
primitive_attr_t::skip_mask_t::post_ops, f32)
&& post_ops_ok();
if (!ok) return status::unimplemented;

auto scratchpad = scratchpad_registry().registrar();
return jit_gemm_convolution_utils::init_conf(jcp_, scratchpad,
*desc(), src_md_, weights_md_, dst_md_, bias_md_, attr_,
@@ -63,35 +62,43 @@ struct gemm_convolution_fwd_t : public primitive_t {

protected:
bool post_ops_ok() const {
using namespace dnnl::impl::primitive_kind;
auto const &po = attr()->post_ops_;
auto is_eltwise
= [&](int idx) { return po.entry_[idx].is_eltwise(); };
auto is_sum = [&](int idx) { return po.entry_[idx].is_sum(); };
auto is_binary
= [&](int idx) { return po.entry_[idx].is_binary(); };

for (int idx = 0; idx < po.len(); idx++) {
bool ok = utils::one_of(true, is_sum(idx), is_binary(idx),
is_eltwise(idx))
&& IMPLICATION(is_sum(idx), idx == 0);
if (!ok) return false;
}

return true;
auto all_post_ops_supported = [&]() {
bool ok = true;

for (int i = 0; i < po.len(); i++) {
ok = ok && utils::one_of(po.entry_[i].kind, sum, eltwise, depthwise, quantization);
}
return ok;
};
auto contain = [&](dnnl::impl::primitive_kind_t kind) { return po.find(kind) != -1; };
auto position = [&](dnnl::impl::primitive_kind_t kind) { return po.find(kind); };
auto count = [&](dnnl::impl::primitive_kind_t kind) { return po.count(kind); };

return all_post_ops_supported() &&
count(primitive_kind::sum) <= 1 &&
IMPLICATION(contain(primitive_kind::sum), position(primitive_kind::sum) == 0);
}
};
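
The rewritten post_ops_ok() above accepts a chain only if every entry is one of sum, eltwise, depthwise or quantization, there is at most one sum, and a sum entry may only appear first. A standalone sketch of the same predicate over a plain enum (hypothetical types, not the real post_ops_t API):

#include <vector>
#include <algorithm>

enum class Kind { sum, eltwise, depthwise, quantization, other };

bool post_ops_ok(const std::vector<Kind> &po) {
    const bool all_supported = std::all_of(po.begin(), po.end(), [](Kind k) {
        return k == Kind::sum || k == Kind::eltwise
            || k == Kind::depthwise || k == Kind::quantization;
    });
    const auto n_sum = std::count(po.begin(), po.end(), Kind::sum);
    const bool sum_first = n_sum == 0 || po.front() == Kind::sum;
    return all_supported && n_sum <= 1 && sum_first;
}
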

gemm_convolution_fwd_t(const pd_t *apd)
: primitive_t(apd), post_ops_(nullptr) {}

status_t init(engine_t *engine) override {
const auto &post_ops = pd()->attr()->post_ops_;
const data_t one = 1.0, zero = 0.0;
const auto &jcp = pd()->jcp_;
beta_ = jcp.with_sum ? one : zero;

if (jcp.with_eltwise || jcp.with_binary)
CHECK(safe_ptr_assign(post_ops_, new ref_post_ops_t(jcp.post_ops)));
return status::success;
bool has_bias = pd()->with_bias();
bool has_post_ops = post_ops.len() > 0;
bool has_scale = !pd()->attr()->output_scales_.has_default_values();
postops_in_ip_ = has_bias || has_post_ops || has_scale;

CHECK(safe_ptr_assign(pp_kernel_, pp_kernel_t::create(pd(), pd()->jcp_)));
return (pp_kernel_) ? pp_kernel_->create_kernel() : status::success;
}

typedef typename prec_traits<data_type::f32>::type data_t;
@@ -110,6 +117,9 @@ struct gemm_convolution_fwd_t : public primitive_t {
const memory_tracking::grantor_t &scratchpad, int MB) const;
const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }

using pp_kernel_t = gemm_convolution_utils::pp_kernel_t;
std::unique_ptr<pp_kernel_t> pp_kernel_;
bool postops_in_ip_;
data_t beta_;

std::unique_ptr<ref_post_ops_t> post_ops_;
144 changes: 144 additions & 0 deletions src/cpu/gemm_convolution_utils.cpp
@@ -22,13 +22,18 @@
#include "common/type_helpers.hpp"
#include "common/utils.hpp"
#include "cpu/gemm_convolution_utils.hpp"

#include "ref_eltwise.hpp"
#include "ref_depthwise_injector.hpp"

#if DNNL_X64
#include "cpu/x64/injectors/jit_uni_postops_injector.hpp"
#endif

#include "cpu/platform.hpp"

#if DNNL_X64
#include "cpu/x64/jit_gemm_convolution_utils.hpp"
#include "cpu/x64/cpu_isa_traits.hpp"
#endif

@@ -50,6 +55,145 @@ single_gemm_conv_chunk_desc_t::single_gemm_conv_chunk_desc_t(dim_t d_off,
, w_off_(w_off)
, w_size_(w_size) {}

namespace gemm_convolution_utils {

struct ref_pp_kernel_t : pp_kernel_t {
ref_pp_kernel_t(const convolution_pd_t *pd, const conv_gemm_conf_t &jcp)
: pp_kernel_t(pd, jcp) {
for (int i = 0; i < post_ops_.len(); i++) {
auto &post_op = post_ops_.entry_[i];
if (post_op.is_eltwise()) {
ref_eltwise_injectors_.push_back(new ref_eltwise_scalar_fwd_t(post_op.eltwise));
} else if (post_op.is_depthwise()) {
ref_depthwise_injectors_.push_back(new ref_depthwise_scalar_fwd_t(
post_op.depthwise.alg));
}
}
}
~ref_pp_kernel_t() {
for (auto impl : ref_eltwise_injectors_)
delete impl;
ref_eltwise_injectors_.clear();
for (auto impl : ref_depthwise_injectors_)
delete impl;
ref_depthwise_injectors_.clear();
}

virtual void operator()(float *dst, const float *bias, const int len, const int oc_start, const int oc_work, const int oc_stride) const override;

private:
nstl::vector<ref_eltwise_scalar_fwd_t*> ref_eltwise_injectors_;
nstl::vector<ref_depthwise_scalar_fwd_t*> ref_depthwise_injectors_;
};

void ref_pp_kernel_t::operator()(float *dst, const float *bias, const int len,const int oc_start, const int oc_work, const int oc_stride) const {
// TODO: for "outer threading" we have parallel section within
// outermost "parallel". It is not good. Consider to use
// "parallel" here with number of threads passed as parameter
const auto &p = post_ops_;
bool need_bias = do_bias_;
if (p.len() > 0) {
int eltwise_inj_idx = 0;
int depthwise_inj_idx = 0;

for (int i = 0; i < p.len(); i++) {
auto &post_op = p.entry_[i];
// todo: sum?
if (post_op.is_eltwise()) {
parallel_nd(oc_work, [&](const int oc) {
float b = need_bias ? bias[oc_start + oc] : 0;
float *d_ = dst + oc * oc_stride;
for (int oS = 0; oS < len; ++oS) {
d_[oS] += b;
d_[oS] = ref_eltwise_injectors_[eltwise_inj_idx]->compute_scalar(d_[oS]);
}
});

eltwise_inj_idx++;
need_bias = false;
} else if (post_op.is_depthwise()) {
auto depthwise_weights = post_op.depthwise.weights_data;
auto depthwise_bias = post_op.depthwise.biases_data;

parallel_nd(oc_work, [&](const int oc) {
float b = need_bias ? bias[oc_start + oc] : 0;
float *d_ = dst + oc * oc_stride;
for (int oS = 0; oS < len; ++oS) {
d_[oS] += b;
d_[oS] = ref_depthwise_injectors_[depthwise_inj_idx]->compute_scalar(d_[oS],
depthwise_weights + oc_start + oc,
depthwise_bias + oc_start + oc);
}
});

depthwise_inj_idx++;
need_bias = false;
} else if (post_op.is_quantization()) {
auto quant = post_op.quantization;
auto pcl = quant.crop_low_data->shifts_;
auto pch = quant.crop_high_data->shifts_;
auto pisc = quant.input_scale_data->scales_;
auto pish = quant.input_shift_data->shifts_;
auto posc = quant.output_scale_data->scales_;
auto posh = quant.output_shift_data->shifts_;

parallel_nd(oc_work, [&](const int oc) {
float b = need_bias ? bias[oc_start + oc] : 0;
float *d_ = dst + oc * oc_stride;

int cl_idx = quant.crop_low_data->count_ == 1 ? 0 : oc_start + oc;
int ch_idx = quant.crop_high_data->count_ == 1 ? 0 : oc_start + oc;
int isc_idx = quant.input_scale_data->count_ == 1 ? 0 : oc_start + oc;
int ish_idx = quant.input_shift_data->count_ == 1 ? 0 : oc_start + oc;
int osc_idx = quant.output_scale_data->count_ == 1 ? 0 : oc_start + oc;
int osh_idx = quant.output_shift_data->count_ == 1 ? 0 : oc_start + oc;

PRAGMA_OMP_SIMD()
for (int oS = 0; oS < len; ++oS) {
d_[oS] += b;

d_[oS] = nstl::min(pch[ch_idx], nstl::max(pcl[cl_idx], d_[oS]));
d_[oS] = d_[oS] * pisc[isc_idx] + pish[ish_idx];
d_[oS] = roundf(d_[oS]);
d_[oS] = d_[oS] * posc[osc_idx] + posh[osh_idx];
}
});

need_bias = false;
}
}
}

if (need_bias) {
parallel_nd(oc_work, [&](const int oc) {
float b = bias[oc_start + oc];
float *d_ = dst + oc * oc_stride;
PRAGMA_OMP_SIMD()
for (int oS = 0; oS < len; ++oS) {
d_[oS] += b;
}
});
}
}
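
The quantization branch above applies, per output channel: a crop to [crop_low, crop_high], an input scale and shift, rounding, and an output scale and shift. A scalar sketch of that formula with the per-channel parameters passed explicitly (hypothetical helper, not part of the patch):

#include <algorithm>
#include <cmath>

// One value through the quantize/dequantize post op, matching the per-channel
// loop above (crop -> input scale/shift -> round -> output scale/shift).
inline float quantize_dequantize(float d, float crop_low, float crop_high,
                                 float in_scale, float in_shift,
                                 float out_scale, float out_shift) {
    d = std::min(crop_high, std::max(crop_low, d));
    d = d * in_scale + in_shift;
    d = std::round(d);
    return d * out_scale + out_shift;
}
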

// Interface section

pp_kernel_t::pp_kernel_t(const convolution_pd_t *pd, const conv_gemm_conf_t &jcp)
: do_bias_(pd->with_bias()), post_ops_(pd->attr()->post_ops_) {}

pp_kernel_t *pp_kernel_t::create(
const convolution_pd_t *pd, const conv_gemm_conf_t &jcp) {
#if DNNL_X64
auto *res
= x64::gemm_convolution_utils::jit_pp_kernel_create(pd, jcp);
if (res) return res;
#endif

return new ref_pp_kernel_t(pd, jcp);
}
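
pp_kernel_t::create() above prefers an x64 JIT kernel and falls back to the portable reference implementation when none can be built. The same dispatch pattern in isolation (stub classes and assumed names, not the real oneDNN factory):

#include <memory>

struct kernel_base { virtual ~kernel_base() = default; };
struct jit_kernel final : kernel_base {};
struct ref_kernel final : kernel_base {};

// Stand-in for the x64 factory: returns nullptr when no JIT kernel can be built
// (e.g. the ISA is unsupported).
std::unique_ptr<kernel_base> try_create_jit_kernel(bool have_supported_isa) {
    if (!have_supported_isa) return nullptr;
    return std::make_unique<jit_kernel>();
}

std::unique_ptr<kernel_base> create_pp_kernel(bool have_supported_isa) {
    if (auto jit = try_create_jit_kernel(have_supported_isa)) return jit; // JIT first
    return std::make_unique<ref_kernel>();                                // portable fallback
}
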

} // namespace gemm_convolution_utils

namespace jit_gemm_convolution_utils {

template <typename data_type_t>
10 changes: 2 additions & 8 deletions src/cpu/gemm_x8s8s32x_convolution.cpp
@@ -197,14 +197,8 @@ status_t gemm_x8s8s32x_convolution_fwd_t::execute_forward_thr(const int ithr,
= jcp.signed_input ? get_wei_comp(wei_base, wei_md) :
jcp.with_input_zp ? output_compensation_base : nullptr;

const bool should_apply_zp_src_comp_pad = jcp.zp.src_exists
&& jit_gemm_convolution_utils::padding_exists(jcp);
const bool should_apply_zp_src_comp_pad_jit_pp
= should_apply_zp_src_comp_pad
&& gemm_x8s8s32x_convolution_utils::mayiuse_jit_pp_kernel();
const bool should_apply_zp_src_comp_outside_pp
= should_apply_zp_src_comp_pad
&& !gemm_x8s8s32x_convolution_utils::mayiuse_jit_pp_kernel();
const bool should_apply_zp_src_comp_pad_jit_pp = false;
const bool should_apply_zp_src_comp_outside_pp = false;

dim_t g {0}, n {0}, ohb {0}, owb {0};
dim_t start = 0, end = 0;
27 changes: 22 additions & 5 deletions src/cpu/gemm_x8s8s32x_convolution.hpp
@@ -71,17 +71,18 @@ struct gemm_x8s8s32x_convolution_fwd_t : public primitive_t {
| primitive_attr_t::skip_mask_t::output_compensations
| primitive_attr_t::skip_mask_t::sum_dt,
dst_type)
&& attr()->post_ops_.check_sum_consistent_dt(dst_type)
&& output_scales_mask_ok() && zero_points_valid(attr());
// && attr()->post_ops_.check_sum_consistent_dt(dst_type)
&& output_scales_mask_ok() && zero_points_valid(attr())
&& post_ops_ok();
if (!ok) return status::unimplemented;

auto scratchpad = scratchpad_registry().registrar();
CHECK(jit_gemm_convolution_utils::init_conf(jcp_, scratchpad,
*desc(), src_md_, weights_md_, dst_md_, bias_md_, attr_,
dnnl_get_max_threads()));
if (!gemm_x8s8s32x_convolution_utils::post_ops_ok(
attr()->post_ops_, &dst_md_))
return status::unimplemented;
// if (!gemm_x8s8s32x_convolution_utils::post_ops_ok(
// attr()->post_ops_, &dst_md_))
// return status::unimplemented;
return status::success;
}

@@ -92,6 +93,22 @@ struct gemm_x8s8s32x_convolution_fwd_t : public primitive_t {
const auto &mask = attr()->output_scales_.mask_;
return mask == 0 || mask == 1 << 1;
}

bool post_ops_ok() const {
using namespace dnnl::impl::primitive_kind;
auto const &po = attr()->post_ops_;

auto all_post_ops_supported = [&]() {
bool ok = true;

for (int i = 0; i < po.len(); i++) {
ok = ok && utils::one_of(po.entry_[i].kind, sum, eltwise, depthwise, quantization);
}
return ok;
};

return all_post_ops_supported();
}
};

gemm_x8s8s32x_convolution_fwd_t(const pd_t *apd) : primitive_t(apd) {}
242 changes: 172 additions & 70 deletions src/cpu/gemm_x8s8s32x_convolution_utils.cpp
@@ -40,27 +40,41 @@ template <typename dst_data_t>
struct ref_pp_ker_t : pp_ker_t {
ref_pp_ker_t(const convolution_pd_t *pd, const conv_gemm_conf_t &jcp)
: pp_ker_t(pd, jcp) {
if (jcp.with_eltwise || jcp.with_binary) {
ref_post_ops_.reset(new ref_post_ops_t(jcp.post_ops));
for (int i = 0; i < post_ops_.len(); i++) {
auto &post_op = post_ops_.entry_[i];
if (post_op.is_eltwise()) {
ref_eltwise_injectors_.push_back(new ref_eltwise_scalar_fwd_t(post_op.eltwise));
} else if (post_op.is_depthwise()) {
ref_depthwise_injectors_.push_back(new ref_depthwise_scalar_fwd_t(
post_op.depthwise.alg));
}
}
}
~ref_pp_ker_t() {
for (auto impl : ref_eltwise_injectors_)
delete impl;
ref_eltwise_injectors_.clear();
for (auto impl : ref_depthwise_injectors_)
delete impl;
ref_depthwise_injectors_.clear();
}

using acc_data_t = pp_ker_t::acc_data_t;

void operator()(void *dst, const acc_data_t *acc, const char *bias,
void operator()(void *dst, acc_data_t *acc, const char *bias,
const float *scales, float sum_scale, float signed_scale, int g,
size_t start, size_t end, const zero_point_call_params_t &zp,
const void *post_ops_binary_rhs_arg_vec, const void *dst_orig,
const exec_ctx_t &ctx, const memory_desc_t &dst_md,
const single_gemm_conv_chunk_desc_t &chunk_desc) const override;

private:
std::unique_ptr<ref_post_ops_t> ref_post_ops_;
nstl::vector<ref_eltwise_scalar_fwd_t*> ref_eltwise_injectors_;
nstl::vector<ref_depthwise_scalar_fwd_t*> ref_depthwise_injectors_;
};

template <typename dst_data_t>
void ref_pp_ker_t<dst_data_t>::operator()(void *void_dst, const acc_data_t *acc,
const char *bias, const float *scales, float sum_scale,
void ref_pp_ker_t<dst_data_t>::operator()(void *void_dst, acc_data_t *acc, const char *bias, const float *scales, float sum_scale,
float signed_scale, int g, size_t start, size_t end,
const zero_point_call_params_t &zp,
const void * /* post_ops_binary_rhs_arg_vec */,
@@ -70,65 +84,177 @@ void ref_pp_ker_t<dst_data_t>::operator()(void *void_dst, const acc_data_t *acc,

if (end <= start) return;

assert(data_traits<dst_data_t>::data_type == jcp_.dst_data_type);

const lldiv_t dv_start = std::div((long long)start, (long long)jcp_.oc);
const lldiv_t dv_end = std::div((long long)(end - 1), (long long)jcp_.oc);
const size_t first_oc = dv_start.rem;
const size_t last_oc = dv_end.rem;
const size_t first_os = dv_start.quot;
const size_t last_os = dv_end.quot;
const int32_t zp_dst_val = jcp_.zp.dst_exists ? *(zp.dst) : 0;
assert(data_traits<dst_data_t>::data_type == dst_data_type_);
dst_data_t *dst = (dst_data_t *)void_dst;

ref_post_ops_t::args_t args;
args.ctx = &ctx;
args.dst_md = &dst_md;
const size_t first_oc = start % OC_;
const size_t last_oc = (end - 1) % OC_;
const size_t first_os = start / OC_;
const size_t last_os = (end - 1) / OC_;
if (post_ops_.len() == 0) {
for (size_t os = first_os; os <= last_os; os++) {
const size_t start_oc = (os == first_os) ? first_oc : 0;
const size_t end_oc = (os == last_os) ? last_oc : OC_ - 1;
for (size_t oc = start_oc; oc <= end_oc; oc++) {
const size_t acc_off = os * jcp_.oc + oc;
const size_t dst_off = os * dst_os_stride_ + oc;

for (size_t os = first_os; os <= last_os; os++) {
const size_t start_oc = (os == first_os) ? first_oc : 0;
const size_t end_oc = (os == last_os) ? last_oc : jcp_.oc - 1;
for (size_t oc = start_oc; oc <= end_oc; oc++) {
const size_t acc_off = os * jcp_.oc + oc;
const size_t dst_off = os * jcp_.dst_os_stride + oc;
float d = (float) (acc[acc_off]);
if (jcp_.signed_input) d *= signed_scale;

int32_t data_s32 = acc[acc_off];
if (do_bias_)
d += math::get_bias(bias, g * jcp_.oc + oc, bias_data_type_);

if (jcp_.zp.src_exists) {
const auto oc_offset = g * jcp_.oc + oc;
data_s32 += zp.src_comp[oc_offset];
d *= scales[(g * jcp_.oc + oc) * scale_idx_mult_];
dst[dst_off] = qz_a1b0<float, dst_data_t>()(d);
}
}
} else {
float* acc_fp = reinterpret_cast<float*>(acc);

float data = static_cast<float>(data_s32);
auto load = [&](int idx, size_t oc, size_t os, size_t acc_off, size_t dst_off) {
float d;
if (idx == 0) {
d = (float) (acc[acc_off]);

if (jcp_.signed_input) data *= signed_scale;
if (jcp_.signed_input)
d *= signed_scale;

if (jcp_.with_bias) {
const float b = io::load_float_value(
jcp_.bias_data_type, bias, g * jcp_.oc + oc);
data += b;
}
if (do_bias_)
d += math::get_bias(bias, g * jcp_.oc + oc,
bias_data_type_);

data *= scales[(g * jcp_.oc + oc) * jcp_.scale_idx_mult];
if (jcp_.with_sum)
data += sum_scale
* io::load_float_value(
jcp_.sum_data_type, void_dst, dst_off);
if (jcp_.with_eltwise || jcp_.with_binary) {
args.l_offset = (g * jcp_.oc + oc) * jcp_.os;
ref_post_ops_->execute(data, args);
d *= scales[(g * jcp_.oc + oc) * scale_idx_mult_];
} else {
d = acc_fp[acc_off];
}

if (jcp_.zp.dst_exists) data += zp_dst_val;
return d;
};

auto store = [&](int idx, float d, size_t acc_off, size_t dst_off) {
if (idx == post_ops_.len() - 1)
dst[dst_off] = qz_a1b0<float, dst_data_t>()(d);
else
acc_fp[acc_off] = d;
};

int eltwise_inj_idx = 0;
int depthwise_inj_idx = 0;
for (int i = 0; i < post_ops_.len(); i++) {
auto &post_op = post_ops_.entry_[i];
if (post_op.is_eltwise()) {
for (size_t os = first_os; os <= last_os; os++) {
const size_t start_oc = (os == first_os) ? first_oc : 0;
const size_t end_oc = (os == last_os) ? last_oc : OC_ - 1;
for (size_t oc = start_oc; oc <= end_oc; oc++) {
const size_t acc_off = os * jcp_.oc + oc;
const size_t dst_off = os * this->dst_os_stride_ + oc;

float d = load(i, oc, os, acc_off, dst_off);

d = ref_eltwise_injectors_[eltwise_inj_idx]->compute_scalar(d);

store(i, d, acc_off, dst_off);
}
}
eltwise_inj_idx++;
} else if (post_op.is_depthwise()) {
for (size_t os = first_os; os <= last_os; os++) {
const size_t start_oc = (os == first_os) ? first_oc : 0;
const size_t end_oc = (os == last_os) ? last_oc : OC_ - 1;
for (size_t oc = start_oc; oc <= end_oc; oc++) {
const size_t acc_off = os * jcp_.oc + oc;
const size_t dst_off = os * this->dst_os_stride_ + oc;

auto depthwise_weights = post_op.depthwise.weights_data;
auto depthwise_bias = post_op.depthwise.biases_data;

float d = load(i, oc, os, acc_off, dst_off);

d = ref_depthwise_injectors_[depthwise_inj_idx]->compute_scalar(d, depthwise_weights + g * jcp_.oc + oc,
depthwise_bias + g * jcp_.oc + oc);

store(i, d, acc_off, dst_off);
}
}
depthwise_inj_idx++;
} else if (post_op.is_quantization()) {
for (size_t os = first_os; os <= last_os; os++) {
const size_t start_oc = (os == first_os) ? first_oc : 0;
const size_t end_oc = (os == last_os) ? last_oc : OC_ - 1;
for (size_t oc = start_oc; oc <= end_oc; oc++) {
const size_t acc_off = os * jcp_.oc + oc;
const size_t dst_off = os * this->dst_os_stride_ + oc;

auto quant = post_op.quantization;
auto pcl = quant.crop_low_data->shifts_;
auto pch = quant.crop_high_data->shifts_;
auto pisc = quant.input_scale_data->scales_;
auto pish = quant.input_shift_data->shifts_;
auto posc = quant.output_scale_data->scales_;
auto posh = quant.output_shift_data->shifts_;

float d = load(i, oc, os, acc_off, dst_off);

int cl_idx = quant.crop_low_data->count_ == 1 ? 0 : g * jcp_.oc + oc;
int ch_idx = quant.crop_high_data->count_ == 1 ? 0 : g * jcp_.oc + oc;
int isc_idx = quant.input_scale_data->count_ == 1 ? 0 : g * jcp_.oc + oc;
int ish_idx = quant.input_shift_data->count_ == 1 ? 0 : g * jcp_.oc + oc;
int osc_idx = quant.output_scale_data->count_ == 1 ? 0 : g * jcp_.oc + oc;
int osh_idx = quant.output_shift_data->count_ == 1 ? 0 : g * jcp_.oc + oc;

d = nstl::min(pch[ch_idx], nstl::max(pcl[cl_idx], d));
d = d * pisc[isc_idx] + pish[ish_idx];
d = roundf(d);
d = d * posc[osc_idx] + posh[osh_idx];

store(i, d, acc_off, dst_off);
}
}
} else if (post_op.is_sum()) {
for (size_t os = first_os; os <= last_os; os++) {
const size_t start_oc = (os == first_os) ? first_oc : 0;
const size_t end_oc = (os == last_os) ? last_oc : OC_ - 1;
for (size_t oc = start_oc; oc <= end_oc; oc++) {
const size_t acc_off = os * jcp_.oc + oc;
const size_t dst_off = os * this->dst_os_stride_ + oc;

io::store_float_value(jcp_.dst_data_type, data, void_dst, dst_off);
float d = load(i, oc, os, acc_off, dst_off);

d += post_op.sum.scale * math::get_sum((char *) dst, dst_off, post_op.sum.dt);

store(i, d, acc_off, dst_off);
}
}
}
}
}
}
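
In the int8 reference kernel above, post ops are applied one full pass at a time: the first entry loads the raw int32 accumulator and applies the signed-input scale, bias and output scales; every entry except the last writes its result back into the accumulator buffer reinterpreted as float so the next entry can pick it up; only the last entry converts and stores into the destination type. A simplified standalone sketch of that chaining (hypothetical names; the per-channel scale and bias are folded into single values, and a separate float scratch buffer is used here instead of reusing the accumulator storage in place):

#include <cstdint>
#include <vector>
#include <functional>
#include <algorithm>
#include <cmath>

static int8_t saturate_s8(float v) {
    return static_cast<int8_t>(std::max(-128.f, std::min(127.f, std::round(v))));
}

// Assumes at least one post op; the original code has a separate no-post-op path.
void apply_post_op_chain(const std::vector<int32_t> &acc, std::vector<int8_t> &dst,
                         const std::vector<std::function<float(float)>> &post_ops,
                         float scale, float bias) {
    std::vector<float> scratch(acc.size());
    for (size_t i = 0; i < post_ops.size(); ++i) {
        for (size_t off = 0; off < acc.size(); ++off) {
            // first pass converts/scales the int32 accumulator, later passes reuse floats
            float d = (i == 0) ? float(acc[off]) * scale + bias : scratch[off];
            d = post_ops[i](d);
            if (i + 1 == post_ops.size())
                dst[off] = saturate_s8(d);  // last post op stores to the destination
            else
                scratch[off] = d;           // intermediate results stay in scratch
        }
    }
}
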

// Interface section

pp_ker_t::pp_ker_t(const convolution_pd_t *pd, const conv_gemm_conf_t &jcp)
: jcp_(jcp) {}
: jcp_(jcp)
, post_ops_(pd->attr()->post_ops_)
, OC_(jcp_.oc)
{
const auto dst_md = memory_desc_wrapper(pd->dst_md());

dst_os_stride_ = dst_md.blocking_desc().strides[pd->ndims() - 1];
dst_data_type_ = dst_md.data_type();

do_scale_ = !pd->attr()->output_scales_.has_default_values();
if (do_scale_) {
scale_idx_mult_ = (pd->attr()->output_scales_.mask_ == (1 << 1));
}

do_bias_ = pd->with_bias();
if (do_bias_) {
bias_data_type_ = pd->desc()->bias_desc.data_type;
assert(bias_data_type_ != data_type::undef);
}
}

pp_ker_t *pp_ker_t::create(
const convolution_pd_t *pd, const conv_gemm_conf_t &jcp) {
@@ -148,30 +274,6 @@ pp_ker_t *pp_ker_t::create(
return nullptr;
}

bool post_ops_ok(const post_ops_t &post_ops, const memory_desc_wrapper *dst_d) {
#if DNNL_X64
return x64::gemm_x8s8s32x_convolution_utils::post_ops_ok(post_ops, dst_d);
#endif
return std::all_of(post_ops.entry_.cbegin(), post_ops.entry_.cend(),
[](const dnnl_post_ops::entry_t &post_op) {
return post_op.is_eltwise() || post_op.is_sum()
|| post_op.is_binary();
});
}

bool post_ops_ok(const post_ops_t &post_ops, const memory_desc_t *dst_d) {
const auto dst_md = memory_desc_wrapper(dst_d);
return post_ops_ok(post_ops, &dst_md);
}

bool mayiuse_jit_pp_kernel() noexcept {
#if DNNL_X64
return x64::gemm_x8s8s32x_convolution_utils::mayiuse_jit_pp_kernel();
#else
return false;
#endif
}

} // namespace gemm_x8s8s32x_convolution_utils
} // namespace cpu
} // namespace impl
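
The new pp_ker_t constructor above derives scale_idx_mult_ from the output-scales mask: mask == 0 means a single common scale and mask == (1 << 1) means one scale per output channel, so scale_idx_mult_ (0 or 1) lets the same expression, scales[(g * jcp_.oc + oc) * scale_idx_mult_], pick the right entry in both cases. A one-line sketch of that selection (hypothetical helper):

// scale_idx_mult is 1 for per-output-channel scales and 0 for a single common
// scale, so the same indexing expression covers both cases.
inline float select_output_scale(const float *scales, size_t oc, size_t scale_idx_mult) {
    return scales[oc * scale_idx_mult];
}
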
17 changes: 12 additions & 5 deletions src/cpu/gemm_x8s8s32x_convolution_utils.hpp
@@ -34,24 +34,31 @@ struct pp_ker_t {

typedef typename prec_traits<data_type::s32>::type acc_data_t;

virtual void operator()(void *dst, const acc_data_t *acc, const char *bias,
virtual void operator()(void *dst, acc_data_t *acc, const char *bias,
const float *scales, float sum_scale, float signed_scale, int g,
size_t start, size_t end, const zero_point_call_params_t &zp,
const void *post_ops_binary_rhs_arg_vec, const void *dst_orig,
const exec_ctx_t &ctx, const memory_desc_t &dst_md,
const single_gemm_conv_chunk_desc_t &chunk_desc) const = 0;

size_t dst_os_stride_;

virtual status_t create_kernel() { return status::success; }

protected:
pp_ker_t(const convolution_pd_t *pd, const conv_gemm_conf_t &jcp);

const conv_gemm_conf_t &jcp_;
};
const post_ops_t &post_ops_;
size_t OC_;

bool post_ops_ok(const post_ops_t &post_ops, const memory_desc_wrapper *dst_d);
bool post_ops_ok(const post_ops_t &post_ops, const memory_desc_t *dst_d);
bool mayiuse_jit_pp_kernel() noexcept;
bool do_bias_ = false;
bool do_scale_ = false;
size_t scale_idx_mult_ = 0;

data_type_t bias_data_type_ = data_type::undef;
data_type_t dst_data_type_ = data_type::undef;
};

} // namespace gemm_x8s8s32x_convolution_utils
} // namespace cpu
111 changes: 61 additions & 50 deletions src/cpu/x64/gemm_bf16_convolution.cpp
@@ -72,41 +72,31 @@ void cvt_acc_to_dst(const conv_gemm_conf_t &jcp, size_t g_start, size_t g_end,
template <data_type_t dst_data_type>
gemm_bf16_convolution_fwd_t<dst_data_type>::pp_ker_t::pp_ker_t(const pd_t *pd)
: jcp_(pd->jcp_)
, post_ops_(pd->attr()->post_ops_)
, do_sum_(dst_data_type != data_type::f32 && jcp_.with_sum)
, max_data_reg_idx_(31)
, max_unroll_(12)
, compute_reg_step_(1)
, data_reg_base_idx_(0) {
, data_reg_base_idx_(0)
, attr_(pd->attr())
, jit_eltwise_injectors_(0)
{
using namespace types;
using namespace Xbyak;

if (!mayiuse(avx512_core))
// bf16 is not supported
return;

const auto &post_ops = jcp_.post_ops;
if (jcp_.with_eltwise || jcp_.with_binary) {
#define PARAM_OFF(field) offsetof(ker_args, field)
static constexpr bool preserve_gpr = true;
static constexpr bool preserve_vmm = true;
static constexpr size_t helper_vmm_idx = 31;
static constexpr size_t tail_size = 1;
static constexpr bool use_exact_tail_scalar_bcast = false;
const binary_injector::rhs_arg_static_params_t rhs_arg_static_params {
helper_vmm_idx, reserved_eltwise_gpr, r14, preserve_gpr,
preserve_vmm, PARAM_OFF(post_ops_binary_rhs_arg_vec),
PARAM_OFF(dst_orig), memory_desc_wrapper(pd->dst_md()),
tail_size, kreg_rem_mask, use_exact_tail_scalar_bcast};
const binary_injector::static_params_t binary_static_params {
this->reg_param, rhs_arg_static_params};
static constexpr bool save_state = true;
const eltwise_injector::static_params_t eltwise_static_params {
save_state, reserved_eltwise_gpr, reserved_eltwise_maskr};

postops_injector_ = utils::make_unique<
injector::jit_uni_postops_injector_t<avx512_core>>(
this, post_ops, binary_static_params, eltwise_static_params);
#undef PARAM_OFF
bool do_depthwise_ = false;
for (int i = 0; i < post_ops_.len(); i++) {
auto& post_op = post_ops_.entry_[i];
if (post_op.is_eltwise()) {
jit_eltwise_injectors_.push_back(new jit_uni_eltwise_injector_f32<avx512_common>(this,
post_op.eltwise, true, reserved_eltwise_gpr, reserved_eltwise_maskr));
} else if (post_op.is_depthwise()) {
do_depthwise_ = true;
}
}

if (do_sum_) {
@@ -116,6 +106,9 @@ gemm_bf16_convolution_fwd_t<dst_data_type>::pp_ker_t::pp_ker_t(const pd_t *pd)

if (jcp_.with_bias) vreg_bias = Zmm(data_reg_base_idx_++);

if (do_depthwise_)
vreg_dw = Zmm(data_reg_base_idx_++);

vlen_ = cpu_isa_traits<avx512_core>::vlen / sizeof(float);

isa_ = mayiuse(avx512_core_bf16) ? avx512_core_bf16
@@ -132,25 +125,6 @@ gemm_bf16_convolution_fwd_t<dst_data_type>::pp_ker_t::pp_ker_t(const pd_t *pd)
= (max_data_reg_idx_ - data_reg_base_idx_ + 1) / compute_reg_step_;
}

template <data_type_t dst_data_type>
void gemm_bf16_convolution_fwd_t<dst_data_type>::pp_ker_t::apply_postops(
const bool apply_mask, const size_t out_offset, const int vmm_idx) {
#define PARAM_OFF(x) offsetof(ker_args, x)
if (jcp_.with_eltwise || jcp_.with_binary) {
if (jcp_.with_binary) {
binary_injector::rhs_arg_dynamic_params_t rhs_arg_params;
rhs_arg_params.vmm_idx_to_out_reg.emplace(vmm_idx, reg_dst);
rhs_arg_params.vmm_idx_to_out_elem_off_val.emplace(
vmm_idx, out_offset * sizeof(dst_data_t));
if (apply_mask) rhs_arg_params.vmm_tail_idx_.emplace(vmm_idx);

postops_injector_->compute_vector(vmm_idx, rhs_arg_params);
} else
postops_injector_->compute_vector(vmm_idx);
}
#undef PARAM_OFF
}

template <data_type_t dst_data_type>
void gemm_bf16_convolution_fwd_t<dst_data_type>::pp_ker_t::generate() {
using namespace Xbyak;
@@ -171,6 +145,8 @@ void gemm_bf16_convolution_fwd_t<dst_data_type>::pp_ker_t::generate() {
mov(reg_len, ptr[reg_param + PARAM_OFF(spatial_length)]);
mov(reg_oc_iter, ptr[reg_param + PARAM_OFF(oc_work)]);

mov(reg_oc_offset, ptr[reg_param + PARAM_OFF(oc_offset)]);

if (jcp_.with_binary) {
// zero initialize binary post_ops offset accumulator (store on stack)
const auto binary_post_op_acc_off_reg = reg_tmp;
@@ -217,7 +193,38 @@ void gemm_bf16_convolution_fwd_t<dst_data_type>::pp_ker_t::generate() {
vfmadd231ps(vreg_dst(idx), vreg_prev_dst(idx), vreg_sum_scale);
}

apply_postops(apply_mask, offset, vreg_dst_idx(idx));
int eltwise_inj_idx = 0;
const auto& p = attr_->post_ops_;
for (int i = 0; i < p.len(); i++) {
auto& post_op = p.entry_[i];
if (post_op.is_eltwise()) {
jit_eltwise_injectors_[eltwise_inj_idx]->compute_vector(vreg_dst_idx(idx));
eltwise_inj_idx++;
} else if (post_op.is_depthwise()) {
mov(reg_dw, reinterpret_cast<size_t>(post_op.depthwise.weights_data));
lea(reg_dw, ptr[reg_dw + reg_oc_offset]);

switch (post_op.depthwise.alg) {
case alg_kind::depthwise_scale_shift: {
vbroadcastss(vreg_dw, ptr[reg_dw]);
vmulps(vreg_dst(idx), vreg_dst(idx), vreg_dw);
mov(reg_dw, reinterpret_cast<size_t>(post_op.depthwise.biases_data));
lea(reg_dw, ptr[reg_dw + reg_oc_offset]);
vbroadcastss(vreg_dw, ptr[reg_dw]);
vaddps(vreg_dst(idx), vreg_dst(idx), vreg_dw);
break;
}
case alg_kind::depthwise_prelu: {
vpxord(vreg_dw, vreg_dw, vreg_dw);
vcmpps(kmask, vreg_dst(idx), vreg_dw, _cmp_lt_os);
vbroadcastss(vreg_dw, ptr[reg_dw]);
vmulps(vreg_dst(idx) | kmask, vreg_dst(idx), vreg_dw);
break;
}
default: assert(!"unsupported depthwise algorithm");
}
}
}

if (dst_data_type == data_type::bf16) {
// TODO: implement store by zmm registers for bf16
@@ -293,6 +300,8 @@ void gemm_bf16_convolution_fwd_t<dst_data_type>::pp_ker_t::generate() {
if (jcp_.with_binary)
inc(EVEX_compress_addr(rsp, reg_binary_post_op_acc_off));

add(reg_oc_offset, sizeof(float));

dec(reg_oc_iter);
jnz(oc_loop, T_NEAR); // oc_loop end

@@ -302,14 +311,15 @@ void gemm_bf16_convolution_fwd_t<dst_data_type>::pp_ker_t::generate() {

postamble();

if (jcp_.with_eltwise) postops_injector_->prepare_table();
for (auto& inj : jit_eltwise_injectors_)
inj->prepare_table();
}

// operator () specialized for nspc format
template <data_type_t dst_data_type>
void gemm_bf16_convolution_fwd_t<dst_data_type>::pp_ker_t::operator()(
dst_data_t *dst, const acc_data_t *acc, const acc_data_t *bias,
float sum_scale, size_t oc_work,
float sum_scale, size_t oc_work, size_t g_offset,
const void *post_ops_binary_rhs_arg_vec, const void *dst_orig,
const size_t g_oc_offset) {

@@ -322,6 +332,7 @@ void gemm_bf16_convolution_fwd_t<dst_data_type>::pp_ker_t::operator()(
args.acc_stride_in_bytes = sizeof(acc_data_t);
args.spatial_length = 1;
args.oc_work = oc_work;
args.oc_offset = g_offset * sizeof(float);

args.post_ops_binary_rhs_arg_vec = post_ops_binary_rhs_arg_vec;
args.dst_orig = dst_orig;
@@ -333,7 +344,7 @@ void gemm_bf16_convolution_fwd_t<dst_data_type>::pp_ker_t::operator()(
template <data_type_t dst_data_type>
void gemm_bf16_convolution_fwd_t<dst_data_type>::pp_ker_t::operator()(
dst_data_t *dst, const acc_data_t *acc, const acc_data_t *bias,
float sum_scale, size_t dst_stride_in_elements,
size_t g_offset, size_t start_oc, float sum_scale, size_t dst_stride_in_elements,
size_t acc_stride_in_elements, size_t sp_len, size_t oc_len,
const void *post_ops_binary_rhs_arg_vec, const void *dst_orig,
const size_t g_oc_offset) {
@@ -348,6 +359,7 @@ void gemm_bf16_convolution_fwd_t<dst_data_type>::pp_ker_t::operator()(
args.acc_stride_in_bytes = acc_stride_in_elements * sizeof(acc_data_t);
args.spatial_length = sp_len;
args.oc_work = oc_len;
args.oc_offset = (start_oc + g_offset) * sizeof(float);

args.post_ops_binary_rhs_arg_vec = post_ops_binary_rhs_arg_vec;
args.dst_orig = dst_orig;
@@ -510,7 +522,7 @@ status_t gemm_bf16_convolution_fwd_t<dst_data_type>::execute_forward_thr_nspc(

(*pp_ker_)(dst_arr,
acc_needed ? acc_arr : (float *)dst_arr,
bia_arr, sum_scale, jcp.oc,
bia_arr, sum_scale, jcp.oc, g * jcp.oc,
post_ops_binary_rhs_arg_vec, dst_base,
g * jcp.oc);
});
@@ -620,8 +632,7 @@ status_t gemm_bf16_convolution_fwd_t<dst_data_type>::execute_forward_ncsp(
if (this->pd()->is_postprocess_required() && ic + ic_block >= jcp.ic) {
size_t acc_str = LDC;
size_t dst_str = M;
float *bias_ptr = bias ? bias + groups * jcp.oc + oc : nullptr;
(*pp_ker_)(dst_local, acc, bias_ptr, sum_scale, dst_str, acc_str, m,
(*pp_ker_)(dst_local, acc, bias, groups * jcp.oc, oc, sum_scale, dst_str, acc_str, m,
oc_block, post_ops_binary_rhs_arg_vec.data(), dst,
groups * jcp.oc + oc);
}
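
The generate() change above emits the depthwise post op inline: for depthwise_scale_shift the destination vector is multiplied by a broadcast per-channel weight and a broadcast per-channel bias is added; for depthwise_prelu only the negative lanes are multiplied by the weight under a compare mask. Scalar equivalents of the two algorithms (a sketch; w and b stand for the per-channel values the kernel broadcasts from post_op.depthwise.weights_data / biases_data):

inline float depthwise_scale_shift(float d, float w, float b) { return d * w + b; }
inline float depthwise_prelu(float d, float w) { return d < 0.f ? d * w : d; }
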
58 changes: 43 additions & 15 deletions src/cpu/x64/gemm_bf16_convolution.hpp
@@ -57,17 +57,8 @@ struct gemm_bf16_convolution_fwd_t : public primitive_t {
&& !has_zero_dim_memory()
&& attr()->has_default_values(
primitive_attr_t::skip_mask_t::post_ops,
dst_data_type);
{
using namespace x64::injector;
static constexpr bool sum_at_pos_0_only = true;
static constexpr bool sum_requires_scale_one = true;
static constexpr bool sum_requires_zp_zero = true;
const auto dst_md = memory_desc_wrapper(dst_md_);
ok &= post_ops_ok({avx512_core, {binary, eltwise, sum},
attr()->post_ops_, &dst_md, sum_at_pos_0_only,
sum_requires_scale_one, sum_requires_zp_zero});
}
dst_data_type)
&& post_ops_ok();
if (!ok) return status::unimplemented;

auto scratchpad = scratchpad_registry().registrar();
@@ -89,6 +80,29 @@ struct gemm_bf16_convolution_fwd_t : public primitive_t {
}

conv_gemm_conf_t jcp_;

protected:
virtual bool post_ops_ok() const {
auto const &po = this->attr()->post_ops_;
auto all_post_ops_supported = [&]() {
bool ok = true;

for (int i = 0; i < po.len(); i++) {
ok = ok && utils::one_of(po.entry_[i].kind, primitive_kind::sum, primitive_kind::eltwise, primitive_kind::depthwise);
}
return ok;
};

auto contain = [&](dnnl::impl::primitive_kind_t kind) { return po.find(kind) != -1; };
auto position = [&](dnnl::impl::primitive_kind_t kind) { return po.find(kind); };
auto count = [&](dnnl::impl::primitive_kind_t kind) { return po.count(kind); };

return all_post_ops_supported() &&
count(primitive_kind::sum) <= 1 &&
IMPLICATION(contain(primitive_kind::sum), position(primitive_kind::sum) == 0);

return false;
}
};

gemm_bf16_convolution_fwd_t(const pd_t *apd)
@@ -135,12 +149,19 @@ struct gemm_bf16_convolution_fwd_t : public primitive_t {
DECLARE_CPU_JIT_AUX_FUNCTIONS(gemm_bf16_convolution_fwd_t::pp_kernel);
pp_ker_t(const pd_t *pd);

~pp_ker_t() {
for (auto inj : jit_eltwise_injectors_)
delete inj;
jit_eltwise_injectors_.clear();
}

void operator()(dst_data_t *dst, const acc_data_t *acc,
const acc_data_t *bias, float sum_scale, size_t oc_work,
const acc_data_t *bias, float sum_scale, size_t oc_work, size_t g_offset,
const void *post_ops_binary_rhs_arg_vec, const void *dst_orig,
const size_t g_oc_offset);
void operator()(dst_data_t *dst, const acc_data_t *acc,
const acc_data_t *bias, float sum_scale, size_t dst_str,
const acc_data_t *bias,
size_t g_offset, size_t start_oc, float sum_scale, size_t dst_str,
size_t acc_str, size_t sp_len, size_t oc,
const void *post_ops_binary_rhs_arg_vec, const void *dst_orig,
const size_t g_oc_offset);
@@ -155,6 +176,7 @@ struct gemm_bf16_convolution_fwd_t : public primitive_t {
size_t acc_stride_in_bytes;
size_t spatial_length;
size_t oc_work;
size_t oc_offset;

size_t g_oc_offset;
const void *post_ops_binary_rhs_arg_vec;
@@ -179,6 +201,11 @@ struct gemm_bf16_convolution_fwd_t : public primitive_t {
Xbyak::Reg64 reg_dst_str = r13;
Xbyak::Reg64 reg_acc_str = r14;

using Vmm = typename cpu_isa_traits<avx512_common>::Vmm;
Xbyak::Reg64 reg_oc_offset = r10;
Xbyak::Reg64 reg_dw = r9;
Xbyak::Opmask kmask = k7;

Xbyak::Reg64 reserved_eltwise_gpr = r10;
Xbyak::Opmask reserved_eltwise_maskr = k2;

@@ -196,14 +223,15 @@ struct gemm_bf16_convolution_fwd_t : public primitive_t {
constexpr static int stack_space_needed = reg64_size;

const conv_gemm_conf_t &jcp_;
post_ops_t post_ops_;
const bool do_sum_;
int max_data_reg_idx_, max_unroll_, compute_reg_step_;
int data_reg_base_idx_;
size_t vlen_;
cpu_isa_t isa_;
std::unique_ptr<bf16_emulation_t> bf16_emu_;
std::unique_ptr<injector::jit_uni_postops_injector_t<avx512_core>>
postops_injector_;
const primitive_attr_t* attr_;
nstl::vector<jit_uni_eltwise_injector_f32<avx512_common>*> jit_eltwise_injectors_;

void apply_postops(const bool apply_mask, const size_t out_offset,
const int vmm_idx);
359 changes: 359 additions & 0 deletions src/cpu/x64/jit_gemm_convolution_utils.cpp
@@ -0,0 +1,359 @@
/*******************************************************************************
* Copyright 2020-2021 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "cpu/x64/jit_generator.hpp"
#include "cpu/x64/injectors/jit_uni_eltwise_injector.hpp"
#include "cpu/x64/injectors/jit_uni_depthwise_injector.hpp"

#include "cpu/x64/jit_gemm_convolution_utils.hpp"

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {
namespace gemm_convolution_utils {

using namespace dnnl::impl::cpu::gemm_convolution_utils;

template <cpu_isa_t isa>
struct jit_pp_kernel_t : pp_kernel_t, public jit_generator {
DECLARE_CPU_JIT_AUX_FUNCTIONS(
gemm_convolution_utils::jit_pp_kernel_t);

jit_pp_kernel_t(const convolution_pd_t *pd, const conv_gemm_conf_t &jcp)
: pp_kernel_t(pd, jcp), idx_compute_vreg_start_(0), idx_compute_vreg_max_(isa == avx512_common ? 31 : 15) {
if (utils::one_of(isa, avx2, sse41)) {
idx_compute_vreg_start_ += 1; // Vmm(0) - for masks
}

bool only_eltwise = true;
for (int i = 0; i < post_ops_.len(); i++) {
auto &post_op = post_ops_.entry_[i];
if (post_op.is_eltwise()) {
jit_eltwise_injectors_.push_back(new jit_uni_eltwise_injector_f32<isa>(
this, post_op.eltwise, true, eltwise_reserved_1_, eltwise_reserved_2_));
} else if (post_op.is_depthwise()) {
only_eltwise = false;
jit_depthwise_injectors_.push_back(new jit_uni_depthwise_injector_f32<isa>(
this, post_op.depthwise.alg, depthwise_reserved_2_));
} else {
only_eltwise = false;
}
}
if (post_ops_.len() > 0 && !only_eltwise) {
vreg_d_weights = Vmm(idx_compute_vreg_max_--);
vreg_d_bias = Vmm(idx_compute_vreg_max_--);
}
if (utils::one_of(isa, avx2, sse41))
vreg_zero = Vmm(idx_compute_vreg_start_++);
}
~jit_pp_kernel_t() {
for (auto inj : jit_eltwise_injectors_)
delete inj;
jit_eltwise_injectors_.clear();
for (auto inj : jit_depthwise_injectors_)
delete inj;
jit_depthwise_injectors_.clear();
}

status_t create_kernel() override { return jit_generator::create_kernel(); }

void operator()(float *dst, const float *bias, const int len, const int oc_start, const int oc_work, const int oc_stride) const override {
for (int oc = 0; oc < oc_work; oc++) {
ker_args_t args;
args.dst = dst + oc * oc_stride;
args.bias = bias + oc_start + oc;
args.len = len;
args.oc_offset = oc_start + oc;
jit_generator::operator()(&args);
}
}

private:
void generate() override;

struct ker_args_t {
float *dst;
const float *bias;
size_t len;
size_t oc_offset;
};

nstl::vector<jit_uni_eltwise_injector_f32<isa> *> jit_eltwise_injectors_;
nstl::vector<jit_uni_depthwise_injector_f32<isa> *> jit_depthwise_injectors_;

using Vmm = typename cpu_isa_traits<isa>::Vmm;
static const size_t vlen = cpu_isa_traits<isa>::vlen / sizeof(float);

Xbyak::Reg64 reg_param = abi_param1;
Xbyak::Reg64 reg_dst = rdx;
Xbyak::Reg64 reg_bias = rbx;

Xbyak::Reg64 reg_len = r8;
Xbyak::Reg64 reg_tmp = rcx; // intentional for shifting purposes
Xbyak::Reg64 reg_oc_offset = r9;
Xbyak::Reg64 reg_rem_mask = r10;
Xbyak::Opmask kreg_rem_mask = k1;

// sse41/avx2
Xbyak::Reg64 reg_ptr_maskmovdqu_dst = rdi; // sse41: store destination - must be rdi
Xbyak::Label l_table;
Xbyak::Reg64 reg_table = r12;
Xbyak::Reg64 reg_shift_table = r13;
Vmm vreg_mask = Vmm(0); // sse41: mask for blendvps must be in xmm0
Vmm vreg_zero;

// post_ops
Xbyak::Reg64 eltwise_reserved_1_ = r11;
Xbyak::Opmask eltwise_reserved_2_ = k2;
Xbyak::Opmask depthwise_reserved_2_ = k2;
Xbyak::Reg64 reg_d_weights = r14;
Xbyak::Reg64 reg_d_bias = r15;
Vmm vreg_d_weights, vreg_d_bias;

int idx_compute_vreg_start_;
int idx_compute_vreg_max_;

int idx_vreg_dst(int iter) {
int idx = idx_compute_vreg_start_ + 0;
assert(idx <= idx_compute_vreg_max_);
return idx;
}
int idx_vreg_bias(int iter) {
int idx = idx_compute_vreg_start_ + 1;
assert(idx <= idx_compute_vreg_max_);
return idx;
}

Vmm vreg_dst(int idx) { return Vmm(idx_vreg_dst(idx)); };
Vmm vreg_bias(int idx) { return Vmm(idx_vreg_bias(idx)); };
};

template <cpu_isa_t isa>
void jit_pp_kernel_t<isa>::generate() {
using namespace Xbyak;
using namespace utils;

preamble();

#define PARAM_OFF(x) offsetof(ker_args_t, x)
mov(reg_dst, ptr[reg_param + PARAM_OFF(dst)]);
mov(reg_bias, ptr[reg_param + PARAM_OFF(bias)]);
mov(reg_len, ptr[reg_param + PARAM_OFF(len)]);
mov(reg_oc_offset, ptr[reg_param + PARAM_OFF(oc_offset)]);
#undef PARAM_OFF

if (utils::one_of(isa, avx2, sse41)) {
uni_vpxor(vreg_zero, vreg_zero, vreg_zero);
mov(reg_table, l_table);
}

auto apply_post_ops = [&]() {
int eltwise_inj_idx = 0;
int depthwise_inj_idx = 0;
auto vreg_dst_ = vreg_dst(0);
for (int i = 0; i < post_ops_.len(); i++) {
auto &post_op = post_ops_.entry_[i];
// todo: antonvor: sum?
if (post_op.is_eltwise()) {
jit_eltwise_injectors_[eltwise_inj_idx]->compute_vector(vreg_dst_.getIdx());
eltwise_inj_idx++;
} else if (post_op.is_depthwise()) {
mov(reg_d_weights, reinterpret_cast<size_t>(post_op.depthwise.weights_data));
mov(reg_d_bias, reinterpret_cast<size_t>(post_op.depthwise.biases_data));
lea(reg_d_weights, ptr[reg_d_weights + reg_oc_offset * sizeof(float)]);
lea(reg_d_bias, ptr[reg_d_bias + reg_oc_offset * sizeof(float)]);
jit_depthwise_injectors_[depthwise_inj_idx]->compute_vector_range(vreg_dst_.getIdx(), vreg_dst_.getIdx() + 1,
reg_d_weights, reg_d_bias, true);
depthwise_inj_idx++;
} else if (post_op.is_quantization()) {
bool do_dequantization = post_op.quantization.alg == alg_kind::quantization_quantize_dequantize;
bool do_rounding = true;

if (post_op.quantization.crop_low_data->count_ != 1) {
mov(reg_d_weights, reinterpret_cast<size_t>(post_op.quantization.crop_low_data->shifts_));
uni_vpbroadcastd(vreg_d_weights, ptr[reg_d_weights + reg_oc_offset * sizeof(float)]);
} else {
mov(reg_d_weights, reinterpret_cast<size_t>(post_op.quantization.crop_low_data->shifts_));
uni_vpbroadcastd(vreg_d_weights, ptr[reg_d_weights]);
}

if (post_op.quantization.crop_high_data->count_ != 1) {
mov(reg_d_bias, reinterpret_cast<size_t>(post_op.quantization.crop_high_data->shifts_));
uni_vpbroadcastd(vreg_d_bias, ptr[reg_d_bias + reg_oc_offset * sizeof(float)]);
} else {
mov(reg_d_bias, reinterpret_cast<size_t>(post_op.quantization.crop_high_data->shifts_));
uni_vpbroadcastd(vreg_d_bias, ptr[reg_d_bias]);
}

uni_vmaxps(vreg_dst_, vreg_dst_, vreg_d_weights);
uni_vminps(vreg_dst_, vreg_dst_, vreg_d_bias);

if (post_op.quantization.input_scale_data->count_ != 1) {
mov(reg_d_weights, reinterpret_cast<size_t>(post_op.quantization.input_scale_data->scales_));
uni_vpbroadcastd(vreg_d_weights, ptr[reg_d_weights + reg_oc_offset * sizeof(float)]);
} else {
mov(reg_d_weights, reinterpret_cast<size_t>(post_op.quantization.input_scale_data->scales_));
uni_vpbroadcastd(vreg_d_weights, ptr[reg_d_weights]);
}

if (post_op.quantization.input_shift_data->count_ != 1) {
mov(reg_d_bias, reinterpret_cast<size_t>(post_op.quantization.input_shift_data->shifts_));
uni_vpbroadcastd(vreg_d_bias, ptr[reg_d_bias + reg_oc_offset * sizeof(float)]);
} else {
mov(reg_d_bias, reinterpret_cast<size_t>(post_op.quantization.input_shift_data->shifts_));
uni_vpbroadcastd(vreg_d_bias, ptr[reg_d_bias]);
}

uni_vfmadd213ps(vreg_dst_, vreg_d_weights, vreg_d_bias);

if (do_rounding)
uni_vroundps(vreg_dst_, vreg_dst_, 0);

if (do_dequantization) {
if (post_op.quantization.output_scale_data->count_ != 1) {
mov(reg_d_weights, reinterpret_cast<size_t>(post_op.quantization.output_scale_data->scales_));
uni_vpbroadcastd(vreg_d_weights, ptr[reg_d_weights + reg_oc_offset * sizeof(float)]);
} else {
mov(reg_d_weights, reinterpret_cast<size_t>(post_op.quantization.output_scale_data->scales_));
uni_vpbroadcastd(vreg_d_weights, ptr[reg_d_weights]);
}

if (post_op.quantization.output_shift_data->count_ != 1) {
mov(reg_d_bias, reinterpret_cast<size_t>(post_op.quantization.output_shift_data->shifts_));
uni_vpbroadcastd(vreg_d_bias, ptr[reg_d_bias + reg_oc_offset * sizeof(float)]);
} else {
mov(reg_d_bias, reinterpret_cast<size_t>(post_op.quantization.output_shift_data->shifts_));
uni_vpbroadcastd(vreg_d_bias, ptr[reg_d_bias]);
}

uni_vfmadd213ps(vreg_dst_, vreg_d_weights, vreg_d_bias);
}
}
}
};

// Load accumulated value, convert to float, apply bias (if any), scaling,
// and eltwise (if any); then convert to destination type and store
auto compute = [&](bool apply_mask) {
auto dst_addr = ptr[reg_dst];
auto vreg_dst_ = vreg_dst(0);
if (isa == avx512_common) {
if (apply_mask)
vreg_dst_ = vreg_dst_ | kreg_rem_mask;
uni_vmovups(vreg_dst_, dst_addr);
} else {
if (apply_mask) {
if (isa != sse41) {
uni_vblendvps(vreg_dst_, vreg_zero, dst_addr, vreg_mask);
} else {
uni_vmovups(vreg_dst_, dst_addr);
}
} else {
uni_vmovups(vreg_dst_, dst_addr);
}
}

if (do_bias_) {
auto vreg_bias_ = vreg_bias(0);
if (isa == avx512_common && apply_mask)
vreg_bias_ = vreg_bias_ | kreg_rem_mask;

uni_vpbroadcastd(vreg_bias_, ptr[reg_bias]);
uni_vaddps(vreg_dst_, vreg_dst_, vreg_bias_);
}

apply_post_ops();

if (isa == avx512_common) {
uni_vmovups(dst_addr, vreg_dst_);
} else {
if (apply_mask) {
if (isa != sse41) {
vmaskmovps(dst_addr, vreg_mask, vreg_dst_);
} else {
lea(reg_ptr_maskmovdqu_dst, dst_addr);
maskmovdqu(vreg_dst_, vreg_mask);
}
} else {
uni_vmovups(dst_addr, vreg_dst_);
}
}
};

Label loop_end;
{
cmp(reg_len, 0);
je(loop_end, T_NEAR);

Label loop, loop_tail;
cmp(reg_len, vlen);
jl(loop_tail, T_NEAR);
L(loop); {
compute(false);
sub(reg_len, vlen);
add(reg_dst, vlen * sizeof(float));
cmp(reg_len, vlen);
jge(loop, T_NEAR);
}

L(loop_tail);
mov(reg_tmp, reg_len); // reg_tmp is rcx, and we need cl for the shift
if (isa == avx512_common) {
mov(reg_rem_mask, 1);
shl(reg_rem_mask, cl); // reg_tmp == rcx and reg_tail < vlen == 16
sub(reg_rem_mask, 1);
jz(loop_end, T_NEAR);
kmovq(kreg_rem_mask, reg_rem_mask);
} else {
mov(reg_shift_table, vlen);
sub(reg_shift_table, reg_tmp);
uni_vmovups(vreg_mask, ptr[reg_table + reg_shift_table * sizeof(float)]);
}
compute(true);
}
L(loop_end);

postamble();

for (auto& inj : jit_eltwise_injectors_)
inj->prepare_table();

if (utils::one_of(isa, avx2, sse41)) {
align(64);
L(l_table);
for (size_t i = 0; i < vlen; i++) dd(0xFFFFFFFF);
for (size_t i = 0; i < vlen; i++) dd(0x00000000);
}
}

pp_kernel_t *jit_pp_kernel_create(
const convolution_pd_t *pd, const conv_gemm_conf_t &jcp) {
if (mayiuse(avx512_common)) {
return new jit_pp_kernel_t<avx512_common>(pd, jcp);
} else if (mayiuse(avx2)) {
return new jit_pp_kernel_t<avx2>(pd, jcp);
} else if (mayiuse(sse41)) {
return new jit_pp_kernel_t<sse41>(pd, jcp);
}
return nullptr;
}

} // namespace gemm_convolution_utils
} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl
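
generate() above processes the destination in vlen-wide chunks and handles the remainder with a mask: on AVX-512 it builds an opmask as (1 << rem) - 1; on SSE4.1/AVX2 it loads a blend mask from a table of vlen 0xFFFFFFFF entries followed by vlen zeros, indexed by vlen - rem. A scalar model of both mask computations (a sketch assuming 0 < rem < vlen, with vlen = 16 for AVX-512 fp32 and 8 for AVX2 fp32):

#include <cstdint>
#include <vector>

// AVX-512 style tail mask: low 'rem' bits set.
uint32_t kmask_for_tail(unsigned rem) { return (1u << rem) - 1u; }

// SSE4.1/AVX2 style tail mask: read vlen dwords from a table of vlen ones
// followed by vlen zeros, starting at offset (vlen - rem), which yields
// 'rem' active lanes followed by inactive ones.
std::vector<uint32_t> blend_mask_for_tail(unsigned vlen, unsigned rem) {
    std::vector<uint32_t> table(2 * vlen, 0u);
    for (unsigned i = 0; i < vlen; ++i) table[i] = 0xFFFFFFFFu;
    return {table.begin() + (vlen - rem), table.begin() + (2 * vlen - rem)};
}
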
36 changes: 36 additions & 0 deletions src/cpu/x64/jit_gemm_convolution_utils.hpp
@@ -0,0 +1,36 @@
/*******************************************************************************
* Copyright 2020-2021 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#ifndef CPU_X64_JIT_GEMM_CONVOLUTION_UTILS_HPP
#define CPU_X64_JIT_GEMM_CONVOLUTION_UTILS_HPP

#include "cpu/gemm_convolution_utils.hpp"

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {
namespace gemm_convolution_utils {

cpu::gemm_convolution_utils::pp_kernel_t *jit_pp_kernel_create(
const convolution_pd_t *pd, const conv_gemm_conf_t &jcp);
} // namespace gemm_convolution_utils
} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl

#endif
1,117 changes: 569 additions & 548 deletions src/cpu/x64/jit_gemm_x8s8s32x_convolution_utils.cpp

Large diffs are not rendered by default.

4 changes: 1 addition & 3 deletions src/cpu/x64/jit_gemm_x8s8s32x_convolution_utils.hpp
@@ -28,10 +28,8 @@ namespace gemm_x8s8s32x_convolution_utils {

cpu::gemm_x8s8s32x_convolution_utils::pp_ker_t *jit_pp_ker_create(
const convolution_pd_t *pd, const conv_gemm_conf_t &jcp);
bool mayiuse_jit_pp_kernel() noexcept;
bool post_ops_ok(const post_ops_t &post_ops, const memory_desc_wrapper *dst_d);

} // namespace gemm_x8s8s32x_convolution_utils
} // namespace gemm_x8s8s32x_convolutilon_utils
} // namespace x64
} // namespace cpu
} // namespace impl