
src: cpu: x64: jit_brdgmm_dw_conw: make brdgmm support fusing sum op #2337

Merged
merged 1 commit into from
Mar 20, 2025

Conversation

amakarev
Contributor

@amakarev amakarev commented Jan 3, 2025

The brdgmm kernel could not be selected when the --attr-post-ops=sum flag was used; jit_dw was dispatched instead.
This PR makes it possible to use the brdgmm kernel with the sum post-op.

The change looks trivial: support was in fact already there, but brdgmm could not be dispatched because the signature of has_default_values was updated long ago while this call site was not. As a result, data_type_undef was always passed, which prevented the brdgmm kernel from being selected.

@amakarev amakarev added the platform:cpu-x64 Intel64/AMD64 processors. Codeowner: @oneapi-src/onednn-cpu-x64 label Jan 3, 2025
@amakarev amakarev requested a review from a team as a code owner January 3, 2025 12:23
@amakarev amakarev force-pushed the amakarev/brdgmm-support-fusing-sum-op branch from 5d7c753 to d4a0ca9 on January 3, 2025 13:59
Contributor

@inteldimitrius inteldimitrius left a comment


LGTM

@dzarukin
Contributor

dzarukin commented Jan 3, 2025

The dst_dt argument is there to support an s8 sum post-op while the dst data type is u8, or vice versa. It doesn't have anything to do with enabling the sum post-op per se. Passing an undef data type requires the sum post-op to use exactly the same data type as the destination tensor, disallowing the divergence described in the first sentence.

It would be nice to see the following:

  • There's actually some coverage for brdgmm + sum post-op in a ci (or some nightly) file.
  • The support flipped from jit to brgemm for the regular sum post-op case (same data type) and for sum_dt != dst_dt (updating the description with verbose output before and after should be fine).
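The requested before/after evidence can be collected with oneDNN's verbose mode filtered to dispatch decisions. A sketch, assuming a benchdnn build at this path and oneDNN v3.x's ONEDNN_VERBOSE dispatch filter, using the reproducer line from this PR:

```shell
# Print why each implementation was skipped for the repro case, then
# keep only the brdgmm lines to compare before and after the patch.
ONEDNN_VERBOSE=dispatch ./build/tests/benchdnn/benchdnn --conv \
    --allow-enum-tags-only=false --dir=FWD_I --stag=acdb --dtag=acdb \
    --attr-post-ops=sum:1:0:f32 --attr-scratchpad=user \
    g32ic32ih7oc32oh5kh3ph0 2>&1 | grep -i brdgmm
```

Before the patch this should show a create:dispatch rejection for brdgmm_dw; after it, a create:cache_miss for brdgmm_dw:avx512_core, as in the logs below.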

@amakarev
Contributor Author

amakarev commented Jan 3, 2025


Got it, thanks for the clarification.

@amakarev amakarev changed the title src: cpu: x64: jit_brdgmm_dw_conw: make brdgmm support fusing sum op [WIP] src: cpu: x64: jit_brdgmm_dw_conw: make brdgmm support fusing sum op Jan 3, 2025
@amakarev
Contributor Author

before

onednn_verbose,v1,info,oneDNN v3.8.0 (commit f081ebd0ed982da5569bc14de3d2ae37fa7b15dd)
onednn_verbose,v1,info,cpu,runtime:OpenMP,nthr:384
onednn_verbose,v1,info,cpu,isa:Intel AVX-512 with float16, Intel DL Boost and bfloat16 support and Intel AMX with bfloat16 and 8-bit integer support
onednn_verbose,v1,info,gpu,runtime:none
onednn_verbose,v1,info,graph,backend,0:dnnl_backend
onednn_verbose,v1,primitive,info,template:operation,engine,primitive,implementation,prop_kind,memory_descriptors,attributes,auxiliary,problem_desc,exec_time
onednn_verbose,v1,graph,info,template:operation,engine,partition_id,partition_kind,op_names,data_formats,logical_tensors,fpmath_mode,implementation,backend,exec_time
onednn_verbose,v1,primitive,create:dispatch,convolution,cpu,convolution,brdgmm_dw:undef,forward_inference,src:f32::blocked:acdb::f0 wei:f32:a:any:any::f0 bia:undef::undef::: dst:f32::blocked:acdb::f0,attr-scratchpad:user attr-post-ops:sum:1:0:f32,alg:convolution_direct,g32mb2_ic32oc32_ih7oh5kh3sh1dh0ph0_iw7ow5kw3sw1dw0pw0,unsupported attribute,src/cpu/x64/jit_brdgmm_dw_conv.cpp:156
onednn_verbose,v1,primitive,create:dispatch,convolution,cpu,convolution,ip:any+,forward_inference,src:f32::blocked:acdb::f0 wei:f32:a:any:any::f0 bia:undef::undef::: dst:f32::blocked:acdb::f0,attr-scratchpad:user attr-post-ops:sum:1:0:f32,alg:convolution_direct,g32mb2_ic32oc32_ih7oh5kh3sh1dh0ph0_iw7ow5kw3sw1dw0pw0,unsupported attribute,src/cpu/x64/ip_convolution.hpp:200
onednn_verbose,v1,primitive,create:dispatch,convolution,heuristic fail: no optimization for depthwise convolution,src/cpu/x64/jit_brgemm_conv_utils.cpp:1815
onednn_verbose,v1,primitive,create:dispatch,convolution,heuristic fail: no optimization for depthwise convolution,src/cpu/x64/jit_brgemm_conv_utils.cpp:1815
onednn_verbose,v1,primitive,create:dispatch,convolution,heuristic fail: no optimization for depthwise convolution,src/cpu/x64/jit_brgemm_conv_utils.cpp:1815
onednn_verbose,v1,primitive,create:dispatch,convolution,heuristic fail: no optimization for depthwise convolution,src/cpu/x64/jit_brgemm_conv_utils.cpp:1815
onednn_verbose,v1,primitive,create:cache_miss,cpu,convolution,jit_dw:avx512_core,forward_inference,src:f32::blocked:acdb::f0 wei:f32:a:blocked:Abcde16a::f0 bia:undef::undef::: dst:f32::blocked:acdb::f0,attr-scratchpad:user attr-post-ops:sum:1:0:f32,alg:convolution_direct,g32mb2_ic32oc32_ih7oh5kh3sh1dh0ph0_iw7ow5kw3sw1dw0pw0,0.306152
onednn_verbose,v1,primitive,create:cache_hit,cpu,convolution,jit_dw:avx512_core,forward_inference,src:f32::blocked:acdb::f0 wei:f32:a:blocked:Abcde16a::f0 bia:undef::undef::: dst:f32::blocked:acdb::f0,attr-scratchpad:user attr-post-ops:sum:1:0:f32,alg:convolution_direct,g32mb2_ic32oc32_ih7oh5kh3sh1dh0ph0_iw7ow5kw3sw1dw0pw0,0.0090332
onednn_verbose,v1,primitive,create:dispatch,reorder,src has a bad number of dimensions 5,src/cpu/x64/matmul/brgemm_matmul_reorders.cpp:297
onednn_verbose,v1,primitive,create:dispatch,reorder,cpu,reorder,jit_direct_copy:uni,undef,src:f32::blocked:abcde::f0 dst:f32::blocked:Abcde16a::f0,,,32x1x1x3x3,memory formats for src and dst tensors do not match,src/cpu/x64/jit_uni_reorder_direct_copy.cpp:310
onednn_verbose,v1,primitive,create:cache_miss,cpu,reorder,jit:blk,undef,src:f32::blocked:abcde::f0 dst:f32::blocked:Abcde16a::f0,,,32x1x1x3x3,0.25
onednn_verbose,v1,primitive,exec,cpu,reorder,jit:blk,undef,src:f32::blocked:abcde::f0 dst:f32::blocked:Abcde16a::f0,,,32x1x1x3x3,16.093
onednn_verbose,v1,primitive,create:dispatch,reorder,small shapes fall back,src/cpu/x64/matmul/brgemm_matmul_reorders.cpp:224
onednn_verbose,v1,primitive,create:dispatch,reorder,cpu,reorder,jit_direct_copy:uni,undef,src:f32::blocked:abcd::f0 dst:f32::blocked:acdb::f0,,,2x32x7x7,memory formats for src and dst tensors do not match,src/cpu/x64/jit_uni_reorder_direct_copy.cpp:310
onednn_verbose,v1,primitive,create:cache_miss,cpu,reorder,jit:uni,undef,src:f32::blocked:abcd::f0 dst:f32::blocked:acdb::f0,,,2x32x7x7,0.281006
onednn_verbose,v1,primitive,exec,cpu,reorder,jit:uni,undef,src:f32::blocked:abcd::f0 dst:f32::blocked:acdb::f0,,,2x32x7x7,81.0479
onednn_verbose,v1,primitive,create:dispatch,reorder,small shapes fall back,src/cpu/x64/matmul/brgemm_matmul_reorders.cpp:224
onednn_verbose,v1,primitive,create:dispatch,reorder,cpu,reorder,jit_direct_copy:uni,undef,src:f32::blocked:abcd::f0 dst:f32::blocked:acdb::f0,,,2x32x5x5,memory formats for src and dst tensors do not match,src/cpu/x64/jit_uni_reorder_direct_copy.cpp:310
onednn_verbose,v1,primitive,create:cache_miss,cpu,reorder,jit:uni,undef,src:f32::blocked:abcd::f0 dst:f32::blocked:acdb::f0,,,2x32x5x5,0.242188
onednn_verbose,v1,primitive,exec,cpu,reorder,jit:uni,undef,src:f32::blocked:abcd::f0 dst:f32::blocked:acdb::f0,,,2x32x5x5,5.53198
onednn_verbose,v1,primitive,exec,cpu,convolution,jit_dw:avx512_core,forward_inference,src:f32::blocked:acdb::f0 wei:f32:a:blocked:Abcde16a::f0 bia:undef::undef::: dst:f32::blocked:acdb::f0,attr-scratchpad:user attr-post-ops:sum:1:0:f32,alg:convolution_direct,g32mb2_ic32oc32_ih7oh5kh3sh1dh0ph0_iw7ow5kw3sw1dw0pw0,11.106
onednn_verbose,v1,primitive,create:dispatch,reorder,small shapes fall back,src/cpu/x64/matmul/brgemm_matmul_reorders.cpp:224
onednn_verbose,v1,primitive,create:dispatch,reorder,cpu,reorder,jit_direct_copy:uni,undef,src:f32::blocked:acdb::f0 dst:f32::blocked:abcd::f0,,,2x32x5x5,memory formats for src and dst tensors do not match,src/cpu/x64/jit_uni_reorder_direct_copy.cpp:310
onednn_verbose,v1,primitive,create:cache_miss,cpu,reorder,jit:uni,undef,src:f32::blocked:acdb::f0 dst:f32::blocked:abcd::f0,,,2x32x5x5,0.315918
onednn_verbose,v1,primitive,exec,cpu,reorder,jit:uni,undef,src:f32::blocked:acdb::f0 dst:f32::blocked:abcd::f0,,,2x32x5x5,16.175
0:PASSED __REPRO: --conv --allow-enum-tags-only=false --dir=FWD_I --stag=acdb --dtag=acdb --attr-post-ops=sum:1:0:f32 --attr-scratchpad=user g32ic32ih7oc32oh5kh3ph0
tests:1 passed:1 skipped:0 mistrusted:0 unimplemented:0 invalid_arguments:0 failed:0 listed:0
total: 0.27s; fill: 0.18s (65%); compute_ref: 0.01s (4%); compare: 0.04s (17%);

after

./build/tests/benchdnn/benchdnn --conv --reset --allow-enum-tags-only=0 --engine=cpu --dir=FWD_I --alg=direct --dt=f32:f32:f32 --stag=acdb --wtag=any --dtag=acdb --attr-post-ops=sum:1.0:0:f32 --attr-scales= --attr-zero-points= --attr-scratchpad=user g32mb2_ic32oc32_ih7oh5kh3sh1dh0ph0_iw7ow5kw3sw1dw0pw0
onednn_verbose,v1,info,oneDNN v3.7.0 (commit 43822b23f38745cb4ab6d6aced221f2c0b4c9e15)
onednn_verbose,v1,info,cpu,runtime:OpenMP,nthr:384
onednn_verbose,v1,info,cpu,isa:Intel AVX-512 with float16, Intel DL Boost and bfloat16 support and Intel AMX with bfloat16 and 8-bit integer support
onednn_verbose,v1,info,gpu,runtime:none
onednn_verbose,v1,info,graph,backend,0:dnnl_backend
onednn_verbose,v1,primitive,info,template:operation,engine,primitive,implementation,prop_kind,memory_descriptors,attributes,auxiliary,problem_desc,exec_time
onednn_verbose,v1,graph,info,template:operation,engine,partition_id,partition_kind,op_names,data_formats,logical_tensors,fpmath_mode,implementation,backend,exec_time
onednn_verbose,v1,primitive,create:cache_miss,cpu,convolution,brdgmm_dw:avx512_core,forward_inference,src:f32::blocked:acdb::f0 wei:f32:a:blocked:debcA16a::f0 bia:undef::undef::: dst:f32::blocked:acdb::f0,attr-scratchpad:user attr-post-ops:sum:1:0:f32,alg:convolution_direct,g32mb2_ic32oc32_ih7oh5kh3sh1dh0ph0_iw7ow5kw3sw1dw0pw0,0.323975
onednn_verbose,v1,primitive,create:cache_hit,cpu,convolution,brdgmm_dw:avx512_core,forward_inference,src:f32::blocked:acdb::f0 wei:f32:a:blocked:debcA16a::f0 bia:undef::undef::: dst:f32::blocked:acdb::f0,attr-scratchpad:user attr-post-ops:sum:1:0:f32,alg:convolution_direct,g32mb2_ic32oc32_ih7oh5kh3sh1dh0ph0_iw7ow5kw3sw1dw0pw0,0.00805664
onednn_verbose,v1,primitive,create:dispatch,reorder,src has a bad number of dimensions 5,src/cpu/x64/matmul/brgemm_matmul_reorders.cpp:294
onednn_verbose,v1,primitive,create:dispatch,reorder,cpu,reorder,jit_direct_copy:uni,undef,src:f32::blocked:abcde::f0 dst:f32::blocked:debcA16a::f0,,,32x1x1x3x3,memory formats for src and dst tensors do not match,src/cpu/x64/jit_uni_reorder_direct_copy.cpp:301
onednn_verbose,v1,primitive,create:cache_miss,cpu,reorder,jit:uni,undef,src:f32::blocked:abcde::f0 dst:f32::blocked:debcA16a::f0,,,32x1x1x3x3,0.294922
onednn_verbose,v1,primitive,exec,cpu,reorder,jit:uni,undef,src:f32::blocked:abcde::f0 dst:f32::blocked:debcA16a::f0,,,32x1x1x3x3,0.0100098
onednn_verbose,v1,primitive,create:cache_miss,cpu,reorder,brgemm_matmul_copy_reorder_t,undef,src:f32::blocked:abcd::f0 dst:f32::blocked:acdb::f0,,,2x32x7x7,0.332031
onednn_verbose,v1,primitive,exec,cpu,reorder,brgemm_matmul_copy_reorder_t,undef,src:f32::blocked:abcd::f0 dst:f32::blocked:acdb::f0,,,2x32x7x7,26.0642
onednn_verbose,v1,primitive,create:cache_miss,cpu,reorder,brgemm_matmul_copy_reorder_t,undef,src:f32::blocked:abcd::f0 dst:f32::blocked:acdb::f0,,,2x32x5x5,0.346924
onednn_verbose,v1,primitive,exec,cpu,reorder,brgemm_matmul_copy_reorder_t,undef,src:f32::blocked:abcd::f0 dst:f32::blocked:acdb::f0,,,2x32x5x5,16.6929
onednn_verbose,v1,primitive,exec,cpu,convolution,brdgmm_dw:avx512_core,forward_inference,src:f32::blocked:acdb::f0 wei:f32:a:blocked:debcA16a::f0 bia:undef::undef::: dst:f32::blocked:acdb::f0,attr-scratchpad:user attr-post-ops:sum:1:0:f32,alg:convolution_direct,g32mb2_ic32oc32_ih7oh5kh3sh1dh0ph0_iw7ow5kw3sw1dw0pw0,16.2871
onednn_verbose,v1,primitive,create:cache_miss,cpu,reorder,brgemm_matmul_copy_reorder_t,undef,src:f32::blocked:acdb::f0 dst:f32::blocked:abcd::f0,,,2x32x5x5,0.339844
onednn_verbose,v1,primitive,exec,cpu,reorder,brgemm_matmul_copy_reorder_t,undef,src:f32::blocked:acdb::f0 dst:f32::blocked:abcd::f0,,,2x32x5x5,21.1331
0:PASSED __REPRO: --conv --allow-enum-tags-only=false --dir=FWD_I --stag=acdb --dtag=acdb --attr-post-ops=sum:1:0:f32 --attr-scratchpad=user g32ic32ih7oc32oh5kh3ph0
tests:1 passed:1 skipped:0 mistrusted:0 unimplemented:0 invalid_arguments:0 failed:0 listed:0
total: 0.36s; fill: 0.10s (28%); compute_ref: 0.03s (8%); compare: 0.19s (52%);

@amakarev amakarev changed the title [WIP] src: cpu: x64: jit_brdgmm_dw_conw: make brdgmm support fusing sum op src: cpu: x64: jit_brdgmm_dw_conw: make brdgmm support fusing sum op Mar 13, 2025
@dzarukin
Contributor

Thanks for providing the logs, I think I see what's going on now.
If the user doesn't specify the data type for sum_dt, everything dispatches fine:

onednn_verbose,v1,primitive,create:cache_miss,cpu,convolution,brdgmm_dw:avx512_core,forward_inference,src:f32::blocked:acdb::f0 wei:f32:a:blocked:debcA16a::f0 bia:undef::undef::: dst:f32::blocked:acdb::f0,attr-scratchpad:user attr-post-ops:sum,alg:convolution_direct,g32mb2_ic32oc32_ih7oh5kh3sh1dh0ph0_iw7ow5kw3sw1dw0pw0,0.364014

but once it's specified, even as the actual dst_dt, it stops dispatching properly because of the check you are modifying:

onednn_verbose,v1,primitive,create:dispatch,convolution,cpu,convolution,brdgmm_dw:undef,forward_inference,src:f32::blocked:acdb::f0 wei:f32:a:any:any::f0 bia:undef::undef::: dst:f32::blocked:acdb::f0,attr-scratchpad:user attr-post-ops:sum:1:0:f32,alg:convolution_direct,g32mb2_ic32oc32_ih7oh5kh3sh1dh0ph0_iw7ow5kw3sw1dw0pw0,unsupported attribute,src/cpu/x64/jit_brdgmm_dw_conv.cpp:156

It looks quite unexpected to me, but it's probably not your battle to fight, so I'm fine with passing this change.

I hope you verified that it's actually faster with brdgmm, because the verbose output you provided doesn't show that it is (though it's a first run, but still).

@mgouicem, we have a debate to talk through.

@amakarev amakarev force-pushed the amakarev/brdgmm-support-fusing-sum-op branch from d4a0ca9 to 8924b67 on March 20, 2025 15:04
@amakarev amakarev requested a review from a team as a code owner March 20, 2025 15:04
@amakarev amakarev merged commit 79273fb into main Mar 20, 2025
12 checks passed
@amakarev amakarev deleted the amakarev/brdgmm-support-fusing-sum-op branch March 20, 2025 17:06