Skip to content

Commit c5fedda

Browse files
committed
xe: ocl: add support for zp in sum post op
1 parent 42d094d commit c5fedda

File tree

2 files changed

+18
-13
lines changed

2 files changed

+18
-13
lines changed

src/gpu/intel/ocl/ocl_post_ops.h

+9-9
Original file line number | Diff line number | Diff line change
@@ -71,24 +71,24 @@ float fwd_Xnary(unsigned kind, unsigned algorithm, float x, float y,
7171
}
7272

7373
#define FMA_BLOCK( \
74-
block_size, nof_elems, acc_ptr, acc_elem_dt, a_ptr, a_elem_dt, b) \
74+
block_size, nof_elems, acc_ptr, acc_elem_dt, a_ptr, a_elem_dt, b, c) \
7575
unroll_for(; nof_elems >= block_size; acc_ptr += block_size, \
7676
a_ptr += block_size, nof_elems -= block_size) { \
7777
CONCAT2(acc_elem_dt, block_size) \
7878
a_conv = CONCAT3(convert_, acc_elem_dt, block_size)( \
7979
*((CONCAT2(a_elem_dt, block_size) *)a_ptr)); \
80-
*((CONCAT2(acc_elem_dt, block_size) *)acc_ptr) = fma( \
81-
a_conv, b, *((CONCAT2(acc_elem_dt, block_size) *)acc_ptr)); \
80+
*((CONCAT2(acc_elem_dt, block_size) *)acc_ptr) = fma(a_conv - c, b, \
81+
*((CONCAT2(acc_elem_dt, block_size) *)acc_ptr)); \
8282
}
8383

84-
#define FMA_MIXED(acc_nof_elems, a, a_elem_dt, b, acc_ptr, acc_elem_dt) \
84+
#define FMA_MIXED(acc_nof_elems, a, a_elem_dt, b, acc_ptr, acc_elem_dt, c) \
8585
{ \
8686
auto nof_elems = acc_nof_elems; \
8787
a_elem_dt *a_ptr = (a_elem_dt *)(&a); \
88-
FMA_BLOCK(8, nof_elems, acc_ptr, acc_elem_dt, a_ptr, a_elem_dt, b); \
89-
FMA_BLOCK(4, nof_elems, acc_ptr, acc_elem_dt, a_ptr, a_elem_dt, b); \
90-
FMA_BLOCK(2, nof_elems, acc_ptr, acc_elem_dt, a_ptr, a_elem_dt, b); \
91-
if (nof_elems == 1) { *acc_ptr += (*a_ptr) * b; } \
88+
FMA_BLOCK(8, nof_elems, acc_ptr, acc_elem_dt, a_ptr, a_elem_dt, b, c); \
89+
FMA_BLOCK(4, nof_elems, acc_ptr, acc_elem_dt, a_ptr, a_elem_dt, b, c); \
90+
FMA_BLOCK(2, nof_elems, acc_ptr, acc_elem_dt, a_ptr, a_elem_dt, b, c); \
91+
if (nof_elems == 1) { *acc_ptr += (*a_ptr - c) * b; } \
9292
}
9393

9494
#define po_dt(idx) CONCAT3(PO_, idx, _BIN_ARG_ACTUAL_DATA_T)
@@ -227,7 +227,7 @@ float fwd_Xnary(unsigned kind, unsigned algorithm, float x, float y,
227227
#define APPLY_PO_SUM( \
228228
idx, accumulator, acc_size, acc_elem_dt, sum_src, sum_elem_dt) \
229229
FMA_MIXED(acc_size, sum_src, sum_elem_dt, CONCAT3(PO_, idx, _SUM_SCALE), \
230-
accumulator, acc_elem_dt);
230+
accumulator, acc_elem_dt, CONCAT3(PO_, idx, _SUM_ZP));
231231

232232
#define APPLY_PO_ELTWISE(idx, accumulator, nelems) \
233233
FWD_XNARY_GENERIC_DT(PO_ELTWISE, CONCAT3(PO_, idx, _ALG), accumulator, \

src/gpu/intel/primitive_conf.cpp

+9-4
Original file line number | Diff line number | Diff line change
@@ -544,7 +544,7 @@ bool post_ops_with_binary_ok(const primitive_attr_t *attr,
544544
const auto &p = attr->post_ops_;
545545

546546
auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(false); };
547-
auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(false); };
547+
auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(false, false); };
548548
auto is_binary = [&](int idx) { return p.entry_[idx].is_binary(); };
549549
auto is_prelu = [&](int idx) { return p.entry_[idx].is_prelu(); };
550550

@@ -564,7 +564,6 @@ bool post_ops_with_binary_ok(const primitive_attr_t *attr,
564564
}
565565
}
566566
if (is_sum(po_idx)) {
567-
if (p.entry_[po_idx].sum.zero_point != 0) return false;
568567
if (p.entry_[po_idx].sum.dt != dnnl_data_type_undef
569568
&& types::data_type_size(p.entry_[po_idx].sum.dt)
570569
!= types::data_type_size(dst_dt))
@@ -693,19 +692,25 @@ status_t def_post_ops_cfg(compute::kernel_ctx_t &kernel_ctx,
693692
("PO_" + std::to_string(idx) + "_ELTWISE_SCALE").c_str(),
694693
1.0f);
695694
}
696-
if (e.is_sum(false)) {
695+
if (e.is_sum(false, false)) {
697696
kernel_ctx.define_int(
698697
"PO_" + std::to_string(idx) + "_KIND", po_sum_id);
699698
kernel_ctx.define_int(
700699
"PO_" + std::to_string(idx) + "_ALG", alg_kind::undef);
701700
kernel_ctx.define_float(
702701
("PO_" + std::to_string(idx) + "_SUM_SCALE").c_str(),
703702
e.sum.scale);
703+
kernel_ctx.define_int(
704+
("PO_" + std::to_string(idx) + "_SUM_ZP").c_str(),
705+
e.sum.zero_point);
706+
704707
} else {
705708
kernel_ctx.define_float(
706709
("PO_" + std::to_string(idx) + "_SUM_SCALE").c_str(), 1.0f);
710+
kernel_ctx.define_int(
711+
("PO_" + std::to_string(idx) + "_SUM_ZP").c_str(), 0);
707712
}
708-
if (!(e.is_binary() || e.is_eltwise(false) || e.is_sum(false)
713+
if (!(e.is_binary() || e.is_eltwise(false) || e.is_sum(false, false)
709714
|| e.is_prelu())) {
710715
// empty post op
711716
kernel_ctx.define_int(

0 commit comments

Comments
 (0)