Skip to content

Commit b90bb0c

Browse files
bartekxk authored and antonvor committed
cpu: x64: pool: enable optimized pooling for pad > ur_w
1 parent f941509 commit b90bb0c

File tree

2 files changed

+37
-18
lines changed

2 files changed

+37
-18
lines changed

src/cpu/x64/jit_uni_pool_kernel.cpp

+32 -18
Original file line numberDiff line numberDiff line change
@@ -272,11 +272,6 @@ status_t jit_uni_pool_kernel<isa>::init_conf(jit_pool_conf_t &jpp,
272272
jpp.ur_bc = 1;
273273
jpp.ur_bc_tail = 0;
274274
}
275-
auto ur_w = nstl::min(jpp.ow, jpp.ur / jpp.ur_bc);
276-
if (utils::div_up(jpp.l_pad, jpp.stride_w) > ur_w)
277-
return status::unimplemented;
278-
if (utils::div_up(right_pad, jpp.stride_w) > ur_w)
279-
return status::unimplemented;
280275

281276
// scratchpad for c_block slice of input and/or output
282277
using namespace memory_tracking::names;
@@ -1301,7 +1296,8 @@ void jit_uni_pool_kernel<isa>::generate() {
13011296

13021297
auto dt_size = jpp.dt_size;
13031298
auto shift = (isa == sse41) ? vlen : 0;
1304-
add(reg_input, dt_size * (ur_w * stride_w - lpad) * c_off - shift);
1299+
add(reg_input,
1300+
dt_size * nstl::max(0, ur_w * stride_w - lpad) * c_off - shift);
13051301
add(reg_output, dt_size * ur_w * c_off - shift);
13061302
if (jpp.alg == pooling_max && (jpp.is_training || jpp.is_backward)) {
13071303
auto ishift = (isa == sse41) ? jpp.c_block / 2 : 0;
@@ -1344,18 +1340,25 @@ void jit_uni_pool_kernel<isa>::generate() {
13441340
auto ur_w = nstl::min(jpp.ow, jpp.ur / jpp.ur_bc);
13451341
auto ur_w_tail = jpp.ow % ur_w;
13461342

1347-
int n_oi = ow / ur_w;
1343+
const int n_oi_iterations = ow / ur_w;
1344+
int n_oi = n_oi_iterations;
13481345

1349-
int r_pad1
1346+
const int r_pad1
13501347
= calculate_end_padding(l_pad, ur_w * n_oi, iw, stride_w, kw);
1351-
if (r_pad1 > 0) n_oi--;
1348+
const int ur_stride_w = ur_w * stride_w;
1349+
const int l_pad_iterations = utils::div_up(l_pad, ur_stride_w);
1350+
const int r_pad_iterations = utils::div_up(r_pad1, ur_stride_w);
13521351

1353-
if (l_pad > 0) {
1352+
n_oi -= nstl::max(0, r_pad_iterations);
1353+
1354+
for (int i = 0; i < l_pad_iterations; ++i) {
13541355
n_oi--;
1356+
const int cur_l_pad = l_pad - i * ur_stride_w;
13551357
if (n_oi < 0 && r_pad1 > 0)
1356-
process_oi(ur_w, ur_bc, l_pad, r_pad1, with_c_tail_processing);
1357-
else
1358-
process_oi(ur_w, ur_bc, l_pad, 0, with_c_tail_processing);
1358+
process_oi(
1359+
ur_w, ur_bc, cur_l_pad, r_pad1, with_c_tail_processing);
1360+
else if (n_oi >= 0)
1361+
process_oi(ur_w, ur_bc, cur_l_pad, 0, with_c_tail_processing);
13591362
}
13601363

13611364
xor_(oi_iter, oi_iter);
@@ -1371,12 +1374,23 @@ void jit_uni_pool_kernel<isa>::generate() {
13711374
}
13721375
}
13731376

1374-
if (r_pad1 > 0 && n_oi >= 0)
1375-
process_oi(ur_w, ur_bc, 0, r_pad1, with_c_tail_processing);
1377+
if (n_oi >= 0) {
1378+
const int r_pad1_tail = r_pad1 % ur_stride_w != 0
1379+
? r_pad1 % ur_stride_w
1380+
: ur_stride_w;
1381+
for (int i = 0; i < r_pad_iterations; ++i) {
1382+
const int cur_r_pad = r_pad1_tail + ur_stride_w * i;
1383+
process_oi(ur_w, ur_bc, 0, cur_r_pad, with_c_tail_processing);
1384+
}
1385+
}
13761386

1377-
if (ur_w_tail != 0)
1378-
process_oi(
1379-
ur_w_tail, ur_bc, 0, r_pad, with_c_tail_processing, false);
1387+
if (ur_w_tail != 0) {
1388+
const int l_pad_tail = n_oi_iterations < l_pad_iterations
1389+
? l_pad % ur_stride_w
1390+
: 0;
1391+
process_oi(ur_w_tail, ur_bc, l_pad_tail, r_pad,
1392+
with_c_tail_processing, false);
1393+
}
13801394
};
13811395
Label ur_bc_tail_label, c_tail_processing_label, finish_label;
13821396

tests/benchdnn/inputs/pool/shapes_2d

+5
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,11 @@ ic35_iw30ih37_ow14oh17_kw3kh4_sw2sh2
106106
ic35_iw30ih36_ow14oh17_pw0ph1_kw3kh4_sw2sh2
107107
ic35_iw33ih37_ow14oh17_kw6kh4_sw2sh2
108108
ic35_iw33ih36_ow14oh17_pw0ph1_kw6kh4_sw2sh2
109+
110+
# Padding is bigger than ur_w
111+
mb1ic8_ih19oh10kh15dh0sh2ph14_iw19ow10kw15dw0sw2pw14
112+
mb1ic8_ih19oh10kh14dh0sh2ph13_iw19ow10kw14dw0sw2pw13
113+
109114
# With dilation
110115
mb1ic8_ih3oh3_kh3ph1_dh2dw2
111116
mb122ic32_ih32iw2_oh32ow2_kh3kw3_ph1pw1_dh4dw1

0 commit comments

Comments
 (0)