@@ -272,11 +272,6 @@ status_t jit_uni_pool_kernel<isa>::init_conf(jit_pool_conf_t &jpp,
272
272
jpp.ur_bc = 1 ;
273
273
jpp.ur_bc_tail = 0 ;
274
274
}
275
- auto ur_w = nstl::min (jpp.ow , jpp.ur / jpp.ur_bc );
276
- if (utils::div_up (jpp.l_pad , jpp.stride_w ) > ur_w)
277
- return status::unimplemented;
278
- if (utils::div_up (right_pad, jpp.stride_w ) > ur_w)
279
- return status::unimplemented;
280
275
281
276
// scratchpad for c_block slice of input and/or output
282
277
using namespace memory_tracking ::names;
@@ -1301,7 +1296,8 @@ void jit_uni_pool_kernel<isa>::generate() {
1301
1296
1302
1297
auto dt_size = jpp.dt_size ;
1303
1298
auto shift = (isa == sse41) ? vlen : 0 ;
1304
- add (reg_input, dt_size * (ur_w * stride_w - lpad) * c_off - shift);
1299
+ add (reg_input,
1300
+ dt_size * nstl::max (0 , ur_w * stride_w - lpad) * c_off - shift);
1305
1301
add (reg_output, dt_size * ur_w * c_off - shift);
1306
1302
if (jpp.alg == pooling_max && (jpp.is_training || jpp.is_backward )) {
1307
1303
auto ishift = (isa == sse41) ? jpp.c_block / 2 : 0 ;
@@ -1344,18 +1340,25 @@ void jit_uni_pool_kernel<isa>::generate() {
1344
1340
auto ur_w = nstl::min (jpp.ow , jpp.ur / jpp.ur_bc );
1345
1341
auto ur_w_tail = jpp.ow % ur_w;
1346
1342
1347
- int n_oi = ow / ur_w;
1343
+ const int n_oi_iterations = ow / ur_w;
1344
+ int n_oi = n_oi_iterations;
1348
1345
1349
- int r_pad1
1346
+ const int r_pad1
1350
1347
= calculate_end_padding (l_pad, ur_w * n_oi, iw, stride_w, kw);
1351
- if (r_pad1 > 0 ) n_oi--;
1348
+ const int ur_stride_w = ur_w * stride_w;
1349
+ const int l_pad_iterations = utils::div_up (l_pad, ur_stride_w);
1350
+ const int r_pad_iterations = utils::div_up (r_pad1, ur_stride_w);
1352
1351
1353
- if (l_pad > 0 ) {
1352
+ n_oi -= nstl::max (0 , r_pad_iterations);
1353
+
1354
+ for (int i = 0 ; i < l_pad_iterations; ++i) {
1354
1355
n_oi--;
1356
+ const int cur_l_pad = l_pad - i * ur_stride_w;
1355
1357
if (n_oi < 0 && r_pad1 > 0 )
1356
- process_oi (ur_w, ur_bc, l_pad, r_pad1, with_c_tail_processing);
1357
- else
1358
- process_oi (ur_w, ur_bc, l_pad, 0 , with_c_tail_processing);
1358
+ process_oi (
1359
+ ur_w, ur_bc, cur_l_pad, r_pad1, with_c_tail_processing);
1360
+ else if (n_oi >= 0 )
1361
+ process_oi (ur_w, ur_bc, cur_l_pad, 0 , with_c_tail_processing);
1359
1362
}
1360
1363
1361
1364
xor_ (oi_iter, oi_iter);
@@ -1371,12 +1374,23 @@ void jit_uni_pool_kernel<isa>::generate() {
1371
1374
}
1372
1375
}
1373
1376
1374
- if (r_pad1 > 0 && n_oi >= 0 )
1375
- process_oi (ur_w, ur_bc, 0 , r_pad1, with_c_tail_processing);
1377
+ if (n_oi >= 0 ) {
1378
+ const int r_pad1_tail = r_pad1 % ur_stride_w != 0
1379
+ ? r_pad1 % ur_stride_w
1380
+ : ur_stride_w;
1381
+ for (int i = 0 ; i < r_pad_iterations; ++i) {
1382
+ const int cur_r_pad = r_pad1_tail + ur_stride_w * i;
1383
+ process_oi (ur_w, ur_bc, 0 , cur_r_pad, with_c_tail_processing);
1384
+ }
1385
+ }
1376
1386
1377
- if (ur_w_tail != 0 )
1378
- process_oi (
1379
- ur_w_tail, ur_bc, 0 , r_pad, with_c_tail_processing, false );
1387
+ if (ur_w_tail != 0 ) {
1388
+ const int l_pad_tail = n_oi_iterations < l_pad_iterations
1389
+ ? l_pad % ur_stride_w
1390
+ : 0 ;
1391
+ process_oi (ur_w_tail, ur_bc, l_pad_tail, r_pad,
1392
+ with_c_tail_processing, false );
1393
+ }
1380
1394
};
1381
1395
Label ur_bc_tail_label, c_tail_processing_label, finish_label;
1382
1396
0 commit comments