@@ -82,7 +82,8 @@ void brgemm_desc_t::cleanup_dst_md() {
82
82
void brgemm_kernel_execute (const brgemm_kernel_t *brg_kernel, int bs,
83
83
const brgemm_batch_element_t *batch, void *ptr_C, void *scratch,
84
84
const brgemm_dynamic_values_t *dynamic_values,
85
- const void *ptr_wei_scales, const void *ptr_wei_zero_points, const void *ptr_src_scales, size_t ic) {
85
+ const void *ptr_wei_scales, const void *ptr_wei_zero_points,
86
+ const void *ptr_src_scales, const void *ptr_src_grouped_sum, size_t ic) {
86
87
brgemm_kernel_params_t brgemm_p;
87
88
88
89
brgemm_p.batch = batch;
@@ -105,6 +106,7 @@ void brgemm_kernel_execute(const brgemm_kernel_t *brg_kernel, int bs,
105
106
brgemm_p.ptr_wei_scales = ptr_wei_scales;
106
107
brgemm_p.ptr_wei_zero_points = ptr_wei_zero_points;
107
108
brgemm_p.ptr_src_scales = ptr_src_scales;
109
+ brgemm_p.ptr_src_grouped_sum = ptr_src_grouped_sum;
108
110
brgemm_p.ic = ic;
109
111
110
112
assert (brg_kernel);
@@ -116,7 +118,8 @@ void brgemm_kernel_execute(const brgemm_kernel_t *brg_kernel, int bs,
116
118
const void *addr_A, const void *addr_B,
117
119
const brgemm_batch_element_t *batch, void *ptr_C, void *scratch,
118
120
const brgemm_dynamic_values_t *dynamic_values,
119
- const void *ptr_wei_scales, const void *ptr_wei_zero_points, const void *ptr_src_scales, size_t ic) {
121
+ const void *ptr_wei_scales, const void *ptr_wei_zero_points,
122
+ const void *ptr_src_scales, const void *ptr_src_grouped_sum, size_t ic) {
120
123
brgemm_kernel_params_t brgemm_p;
121
124
122
125
brgemm_p.batch = batch;
@@ -133,6 +136,7 @@ void brgemm_kernel_execute(const brgemm_kernel_t *brg_kernel, int bs,
133
136
brgemm_p.ptr_wei_scales = ptr_wei_scales;
134
137
brgemm_p.ptr_wei_zero_points = ptr_wei_zero_points;
135
138
brgemm_p.ptr_src_scales = ptr_src_scales;
139
+ brgemm_p.ptr_src_grouped_sum = ptr_src_grouped_sum;
136
140
brgemm_p.ic = ic;
137
141
if (dynamic_values) {
138
142
brgemm_p.dynamic_LDA = dynamic_values->dynamic_LDA ;
@@ -148,7 +152,8 @@ void brgemm_kernel_execute_postops(const brgemm_kernel_t *brg_kernel, int bs,
148
152
const brgemm_batch_element_t *batch, void *ptr_C, void *ptr_D,
149
153
const brgemm_post_ops_data_t &post_ops_data, void *scratch,
150
154
const brgemm_dynamic_values_t *dynamic_values,
151
- const void *ptr_wei_scales, const void *ptr_wei_zero_points, const void *ptr_src_scales, size_t ic) {
155
+ const void *ptr_wei_scales, const void *ptr_wei_zero_points,
156
+ const void *ptr_src_scales, const void *ptr_src_grouped_sum, size_t ic) {
152
157
brgemm_kernel_params_t brgemm_p;
153
158
154
159
brgemm_p.batch = batch;
@@ -178,6 +183,7 @@ void brgemm_kernel_execute_postops(const brgemm_kernel_t *brg_kernel, int bs,
178
183
brgemm_p.ptr_wei_scales = ptr_wei_scales;
179
184
brgemm_p.ptr_wei_zero_points = ptr_wei_zero_points;
180
185
brgemm_p.ptr_src_scales = ptr_src_scales;
186
+ brgemm_p.ptr_src_grouped_sum = ptr_src_grouped_sum;
181
187
brgemm_p.ic = ic;
182
188
if (dynamic_values) {
183
189
brgemm_p.dynamic_LDA = dynamic_values->dynamic_LDA ;
@@ -194,7 +200,8 @@ void brgemm_kernel_execute_postops(const brgemm_kernel_t *brg_kernel, int bs,
194
200
const brgemm_batch_element_t *batch, void *ptr_C, void *ptr_D,
195
201
const brgemm_post_ops_data_t &post_ops_data, void *scratch,
196
202
const brgemm_dynamic_values_t *dynamic_values,
197
- const void *ptr_wei_scales, const void *ptr_wei_zero_points, const void *ptr_src_scales, size_t ic) {
203
+ const void *ptr_wei_scales, const void *ptr_wei_zero_points,
204
+ const void *ptr_src_scales, const void *ptr_src_grouped_sum, size_t ic) {
198
205
brgemm_kernel_params_t brgemm_p;
199
206
200
207
brgemm_p.batch = batch;
@@ -224,6 +231,7 @@ void brgemm_kernel_execute_postops(const brgemm_kernel_t *brg_kernel, int bs,
224
231
brgemm_p.ptr_wei_scales = ptr_wei_scales;
225
232
brgemm_p.ptr_wei_zero_points = ptr_wei_zero_points;
226
233
brgemm_p.ptr_src_scales = ptr_src_scales;
234
+ brgemm_p.ptr_src_grouped_sum = ptr_src_grouped_sum;
227
235
brgemm_p.ic = ic;
228
236
if (dynamic_values) {
229
237
brgemm_p.dynamic_LDA = dynamic_values->dynamic_LDA ;
@@ -318,6 +326,12 @@ status_t brgemm_desc_init(brgemm_desc_t *brg, cpu_isa_t isa,
318
326
319
327
CHECK (brgemm_blocking (brg));
320
328
329
+ brg->src_sum_group_size = wei_d.dims ()[1 ];
330
+ if (brg->with_src_dyn_quant ) {
331
+ brg->src_sum_group_size = brg->rd_block ;
332
+ brg->src_grouped_sum_stride = div_up (wei_d.dims ()[1 ], brg->src_sum_group_size );
333
+ }
334
+
321
335
// avx2_vnni_2 kernel with xf16 data type requires blocked weights.
322
336
if (brg->isa_impl == avx2_vnni_2 && brg->is_xf16 ()
323
337
&& brg->LDB % brg->ld_block > 0 )
0 commit comments