@@ -71,24 +71,24 @@ float fwd_Xnary(unsigned kind, unsigned algorithm, float x, float y,
71
71
}
72
72
73
73
#define FMA_BLOCK ( \
74
- block_size , nof_elems , acc_ptr , acc_elem_dt , a_ptr , a_elem_dt , b ) \
74
+ block_size , nof_elems , acc_ptr , acc_elem_dt , a_ptr , a_elem_dt , b , c ) \
75
75
unroll_for(; nof_elems >= block_size; acc_ptr += block_size, \
76
76
a_ptr += block_size, nof_elems -= block_size) { \
77
77
CONCAT2(acc_elem_dt, block_size) \
78
78
a_conv = CONCAT3(convert_, acc_elem_dt, block_size)( \
79
79
*((CONCAT2(a_elem_dt, block_size) *)a_ptr)); \
80
- *((CONCAT2(acc_elem_dt, block_size) *)acc_ptr) = fma( \
81
- a_conv, b, *((CONCAT2(acc_elem_dt, block_size) *)acc_ptr)); \
80
+ *((CONCAT2(acc_elem_dt, block_size) *)acc_ptr) = fma(a_conv - c, b, \
81
+ *((CONCAT2(acc_elem_dt, block_size) *)acc_ptr)); \
82
82
}
83
83
84
- #define FMA_MIXED (acc_nof_elems , a , a_elem_dt , b , acc_ptr , acc_elem_dt ) \
84
+ #define FMA_MIXED (acc_nof_elems , a , a_elem_dt , b , acc_ptr , acc_elem_dt , c ) \
85
85
{ \
86
86
auto nof_elems = acc_nof_elems; \
87
87
a_elem_dt *a_ptr = (a_elem_dt *)(&a); \
88
- FMA_BLOCK(8, nof_elems, acc_ptr, acc_elem_dt, a_ptr, a_elem_dt, b); \
89
- FMA_BLOCK(4, nof_elems, acc_ptr, acc_elem_dt, a_ptr, a_elem_dt, b); \
90
- FMA_BLOCK(2, nof_elems, acc_ptr, acc_elem_dt, a_ptr, a_elem_dt, b); \
91
- if (nof_elems == 1) { *acc_ptr += (*a_ptr) * b; } \
88
+ FMA_BLOCK(8, nof_elems, acc_ptr, acc_elem_dt, a_ptr, a_elem_dt, b, c ); \
89
+ FMA_BLOCK(4, nof_elems, acc_ptr, acc_elem_dt, a_ptr, a_elem_dt, b, c ); \
90
+ FMA_BLOCK(2, nof_elems, acc_ptr, acc_elem_dt, a_ptr, a_elem_dt, b, c ); \
91
+ if (nof_elems == 1) { *acc_ptr += (*a_ptr - c ) * b; } \
92
92
}
93
93
94
94
#define po_dt (idx ) CONCAT3(PO_, idx, _BIN_ARG_ACTUAL_DATA_T)
@@ -227,7 +227,7 @@ float fwd_Xnary(unsigned kind, unsigned algorithm, float x, float y,
227
227
#define APPLY_PO_SUM ( \
228
228
idx , accumulator , acc_size , acc_elem_dt , sum_src , sum_elem_dt ) \
229
229
FMA_MIXED(acc_size, sum_src, sum_elem_dt, CONCAT3(PO_, idx, _SUM_SCALE), \
230
- accumulator, acc_elem_dt);
230
+ accumulator, acc_elem_dt, CONCAT3(PO_, idx, _SUM_ZP) );
231
231
232
232
#define APPLY_PO_ELTWISE (idx , accumulator , nelems ) \
233
233
FWD_XNARY_GENERIC_DT(PO_ELTWISE, CONCAT3(PO_, idx, _ALG), accumulator, \
0 commit comments