@@ -35,6 +35,7 @@ namespace x64 {
35
35
using namespace dnnl ::impl::format_tag;
36
36
using namespace dnnl ::impl::prop_kind;
37
37
using namespace dnnl ::impl::utils;
38
+ using namespace dnnl ::impl::memory_tracking::names;
38
39
39
40
using namespace Xbyak ;
40
41
@@ -398,6 +399,7 @@ status_t jit_sse41_conv_fwd_kernel_f32::init_conf(jit_conv_conf_t &jcp,
398
399
jcp.mb = src_d.dims ()[0 ];
399
400
400
401
jcp.oc = dst_d.dims ()[1 ] / jcp.ngroups ;
402
+ jcp.oc_without_padding = jcp.oc ;
401
403
jcp.ic = src_d.dims ()[1 ] / jcp.ngroups ;
402
404
403
405
jcp.ih = (ndims == 3 ) ? 1 : src_d.dims ()[2 ];
@@ -491,7 +493,15 @@ status_t jit_sse41_conv_fwd_kernel_f32::init_conf(jit_conv_conf_t &jcp,
491
493
VDISPATCH_CONV_IC (channel_pad_ok, VERBOSE_UNSUPPORTED_PAD_FEATURE,
492
494
" i/o and padded channel size mismatch" );
493
495
496
+ bool ok_to_pad_channels = true && jcp.ngroups == 1 ;
497
+
494
498
const int simd_w = 8 ; // 2 SSE vectors processing at once
499
+ if (ok_to_pad_channels) {
500
+ jcp.oc = rnd_up (jcp.oc , simd_w);
501
+ if (mimo) {
502
+ jcp.ic = rnd_up (jcp.ic , simd_w);
503
+ }
504
+ }
495
505
496
506
jcp.ur_h = 1 ; /* no code-unrolling by h so far */
497
507
jcp.ur_w = 3 ;
@@ -549,6 +559,15 @@ status_t jit_sse41_conv_fwd_kernel_f32::init_conf(jit_conv_conf_t &jcp,
549
559
return status::success;
550
560
}
551
561
562
+ void jit_sse41_conv_fwd_kernel_f32::init_scratchpad (
563
+ memory_tracking::registrar_t &scratchpad,
564
+ const jit_conv_conf_t &jcp) {
565
+ using namespace dnnl ::impl::memory_tracking::names;
566
+
567
+ if (jcp.prop_kind != backward_data && jcp.oc != jcp.oc_without_padding )
568
+ scratchpad.book <float >(key_conv_padded_bias, sizeof (float ) * jcp.oc );
569
+ }
570
+
552
571
} // namespace x64
553
572
} // namespace cpu
554
573
} // namespace impl
0 commit comments