Fix accuracy of max-pooling backpropagation for bfloat16 data #2386
Conversation
src/cpu/x64/jit_uni_pool_kernel.cpp (Outdated)
```diff
@@ -18,6 +18,7 @@
 #include <bitset>

 #include "common/dnnl_thread.hpp"
+#include "common/memory_desc.hpp"
```
This one should come with "cpu/cpu_pooling_pd.hpp"; thus, it is not needed as a standalone include.
Removed.
src/cpu/x64/jit_uni_pool_kernel.cpp (Outdated)
```cpp
}

template <cpu_isa_t isa>
inline void jit_uni_pool_kernel<isa>::load32(const int idx,
```
There are already a lot of existing routines that support loading. I highly recommend changing the loading/storing implementation to rely on the io_injector, which is more flexible.
Yes, the loading/storing functions need refactoring. I did not know how to do that properly, and I did not know about the io_injector; I will have to investigate it. Could this be done as a separate task?
If doing it as a separate task, I definitely recommend refactoring the existing implementation first, and then applying f32 accumulation together with acc_mode support.
The io injector is now used to load/store tensor data, but indices are processed the same way as before because the io injector does not support these data types: it converts integers to floats during loading. (However, it does not convert the data when they are stored as s32; s32 and f32 are stored by one common function, store_f32, which looks like a bug.)
src/cpu/x64/jit_uni_pool_kernel.hpp (Outdated)
```cpp
// While this is only required by the backward pass, the quirk above
// is applied to the forward pass as well to keep things simpler.
Opmask k_c_tail_mask = Opmask(
        4); // is shared with jit_io_multi_dt_helper_t and jit_uni_postops_injector_t
```
Style: it would be nice to have the comment above the declaration for better readability.
Updated.
```
--attr-post-ops=

--alg=max
--tag=axb,aBx8b,aBx16b
```
Let's drop the aBx8b layout from here; it shouldn't be the case for backward at all. I'd drop 16b as well, but it's fine to keep what we had before the change.
aBx8b is dropped.
```
--dir=BWD_D
--attr-acc-mode=relaxed
--batch=set_all
--batch=set_topologies
```
Let's keep the smaller batch of these two (I don't know which one it is).
set_all is dropped.
Do not use an f32 accumulator in jit_uni_pooling for max-pooling backpropagation with bf16 if the 'relaxed' or 'any' accumulation mode is specified. Use a zero error threshold in tests for max pooling if the 'strict' or 'f32' accumulation mode is specified.
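For context, the accumulation mode is requested through a primitive attribute on the user side. Below is a minimal sketch of how the relaxed mode could be requested for a bf16 max-pooling backward pass; it assumes oneDNN 3.3+ (where `primitive_attr::set_accumulation_mode` is available), and the shapes, tags, and pooling parameters are purely illustrative.

```cpp
#include "dnnl.hpp"

using namespace dnnl;

// Sketch: request relaxed accumulation for bf16 max-pooling backward.
// Creation may still fail at runtime on CPUs without bf16 support.
int main() {
    engine eng(engine::kind::cpu, 0);

    memory::desc src_md({1, 16, 32, 32}, memory::data_type::bf16,
            memory::format_tag::nhwc); // the "axb" layout
    memory::desc dst_md({1, 16, 16, 16}, memory::data_type::bf16,
            memory::format_tag::nhwc);

    primitive_attr attr;
    // 'relaxed'/'any' allow skipping the f32 accumulator for speed;
    // 'strict' (the default) and 'f32' require exact f32 accumulation.
    attr.set_accumulation_mode(accumulation_mode::relaxed);

    // Overlapping windows: kernel 3x3 with stride 2x2.
    pooling_forward::primitive_desc fwd_pd(eng, prop_kind::forward_training,
            algorithm::pooling_max, src_md, dst_md, /*strides=*/{2, 2},
            /*kernel=*/{3, 3}, /*dilation=*/{0, 0}, /*padding_l=*/{0, 0},
            /*padding_r=*/{1, 1}, attr);

    pooling_backward::primitive_desc bwd_pd(eng, algorithm::pooling_max,
            src_md, dst_md, {2, 2}, {3, 3}, {0, 0}, {0, 0}, {1, 1}, fwd_pd,
            attr);
    (void)bwd_pd;
    return 0;
}
```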
Fixes:
- MFDNN-11050: bf16 backward max pooling returns incorrect results
- MFDNN-11396: BF16 pooling_backward performance regression on SPR
As a result of the refactoring, MFDNN-12863 (the JIT max-pool implementation works incorrectly for small data types and large kernels) was also fixed.
Also, as recommended during the review, the io injector is now used to load/store tensor data.
It was found earlier (MFDNN-11050) that bf16 backward max pooling returns incorrect results. An initial accuracy fix led to a significant performance regression (MFDNN-11396), so it was rolled back.
The root cause of the accuracy issue is that even a sum of relatively small bf16 numbers is inaccurate, e.g. bf16(256.0) + bf16(1.0) equals bf16(256.0). Such summation takes place when some pooling strides are smaller than the corresponding kernel sizes, because the overlapping windows make several dst_diff values contribute to the same src_diff element; a small demonstration follows.
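To make the rounding loss concrete, here is a small standalone sketch that emulates bf16 round-to-nearest-even by keeping only the top 16 bits of the f32 representation (the helper is illustrative, not oneDNN code):

```cpp
#include <cstdint>
#include <cstdio>
#include <cstring>

// Round an f32 value to the nearest bf16 value (round-to-nearest-even,
// ignoring NaN/inf edge cases) and expand it back to f32.
static float round_to_bf16(float f) {
    uint32_t bits;
    std::memcpy(&bits, &f, sizeof(bits));
    bits += 0x7FFF + ((bits >> 16) & 1); // rounding bias
    bits &= 0xFFFF0000u;                 // drop the low 16 mantissa bits
    float out;
    std::memcpy(&out, &bits, sizeof(out));
    return out;
}

int main() {
    float acc = round_to_bf16(256.0f);
    acc = round_to_bf16(acc + round_to_bf16(1.0f)); // bf16 accumulation
    std::printf("bf16(256.0) + bf16(1.0) = %g\n", acc); // prints 256, not 257
    return 0;
}
```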
The current fix uses additional f32 accumulation arrays, one array per thread. The size of those arrays for src_diff is the same as in the existing ncsp implementation (the ncsp implementation creates f32 arrays for dst_diff, src_diff and indices, reorders the data, and uses those arrays during the calculations). The ncsp case is not affected by this PR (except for the changed load/store functions). A scalar sketch of the idea is shown below.
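This is a minimal scalar sketch of the per-thread f32 accumulation idea, with hypothetical names; the real implementation is a blocked JIT kernel, so this only illustrates why accumulating in f32 and rounding to bf16 once restores accuracy.

```cpp
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <vector>

using bf16_t = uint16_t; // raw bf16 bits

static float bf16_to_f32(bf16_t v) {
    uint32_t bits = uint32_t(v) << 16;
    float f;
    std::memcpy(&f, &bits, sizeof(f));
    return f;
}

static bf16_t f32_to_bf16(float f) {
    uint32_t bits;
    std::memcpy(&bits, &f, sizeof(bits));
    bits += 0x7FFF + ((bits >> 16) & 1); // round to nearest even
    return bf16_t(bits >> 16);
}

// One channel of max-pooling backward: each dst_diff element is routed to
// the src element whose index was recorded by the forward pass. With
// overlapping windows, several dst_diff values can land on the same
// src_diff element, so the sums are kept in a per-thread f32 buffer and
// rounded to bf16 only once at the end.
static void max_pool_bwd_channel(const bf16_t *dst_diff,
        const int32_t *indices, std::size_t dst_len, bf16_t *src_diff,
        std::size_t src_len) {
    std::vector<float> acc(src_len, 0.0f); // per-thread f32 accumulator
    for (std::size_t od = 0; od < dst_len; ++od)
        acc[std::size_t(indices[od])] += bf16_to_f32(dst_diff[od]);
    for (std::size_t is = 0; is < src_len; ++is)
        src_diff[is] = f32_to_bf16(acc[is]);
}

int main() {
    // Three overlapping windows route their gradients to src element 0.
    const bf16_t dst_diff[3]
            = {f32_to_bf16(256.0f), f32_to_bf16(1.0f), f32_to_bf16(1.0f)};
    const int32_t indices[3] = {0, 0, 0};
    bf16_t src_diff[1];
    max_pool_bwd_channel(dst_diff, indices, 3, src_diff, 1);
    // f32 accumulation yields 258 exactly; bf16 accumulation would give 256.
    return bf16_to_f32(src_diff[0]) == 258.0f ? 0 : 1;
}
```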
I have done some manual measurements on a machine with an SPR processor. In some cases this implementation works faster than the original version and sometimes slower, but it is significantly faster than the non-optimized implementation (the one used after the first fix of MFDNN-11050).
The following tables contain performance data for the axb and aBx16b layouts for the original implementation (main branch), the fixed version (this PR), and the alternative implementation (the one used when the optimized implementation is skipped).
Scratch of a script used to run the tests:
axb: [performance table]
aBx16b: [performance table]