How to debug conv2d backward "Unexpected page fault" error #2864
Hi @hoshibara, thank you for providing verbose logs; that gives some perspective and tips. However, if you could provide ONEDNN_VERBOSE=all output (as opposed to =1), that may give some more context for the team triaging the issue. If the problem is in the execution of the backward path, the log should report a create line for that backward convolution, and it would help to see whether the parameters matched and which implementation was used. It was also mentioned that this started after the upgrade to v3.7. Does that imply the previous version was working well? And if so, which version was that? And, in case you are eager to help us with debugging, there are some more questions:
Answering these questions will help to narrow down the scope and potentially simplify the standalone oneDNN reproducer. Thank you.
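For reference, a minimal sketch of one way to collect that output from the Python side (this is a sketch, not the reporter's actual script; it assumes the environment variable takes effect as long as it is set before torch, and with it oneDNN, is loaded; equivalently, the run can simply be prefixed with ONEDNN_VERBOSE=all on the command line):

# Hedged sketch: request the fuller ONEDNN_VERBOSE=all output (as opposed to =1).
import os

os.environ["ONEDNN_VERBOSE"] = "all"

import torch  # imported after the environment variable is set

# ... run the conv2d backward reproducer here (hypothetical placeholder) ...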
Hi @dzarukin, thank you for your patience and suggestions.

It was also mentioned it started after upgrade to v3.7. Does it imply the previous version was working well? And if so, which version was that?

I rolled back the oneDNN version integrated into the latest PyTorch and confirmed that this issue first appeared in v3.7.0. The diff between the processed verbose logs of the two versions is as follows:

diff --git a/shape_1.3_6_2.processed.log b/shape_1.3_7.processed.log
index d3bd144..25dad21 100644
--- a/shape_1.3_6_2.processed.log
+++ b/shape_1.3_7.processed.log
@@ -1,13 +1,14 @@
-onednn_verbose,v1,info,oneDNN v3.6.2 (commit 2eb3dd1082db767fab171e934c551c609008289a)
+onednn_verbose,v1,info,oneDNN v3.7.0 (commit 5e9224036021433d2577548ed0539fe9a53256bc)
onednn_verbose,v1,info,cpu,runtime:threadpool,nthr:56
onednn_verbose,v1,info,cpu,isa:Intel AVX-512 with float16, Intel DL Boost and bfloat16 support
onednn_verbose,v1,info,gpu,runtime:DPC++
+onednn_verbose,v1,info,gpu,engine,sycl gpu device count:1
onednn_verbose,v1,info,gpu,engine,0,backend:Level Zero,name:Intel(R) Data Center GPU Max 1550,driver_version:1.6.31294,binary_kernels:enabled
onednn_verbose,v1,info,graph,backend,0:dnnl_backend
onednn_verbose,v1,primitive,info,template:operation,engine,primitive,implementation,prop_kind,memory_descriptors,attributes,auxiliary,problem_desc,exec_time
onednn_verbose,v1,graph,info,template:operation,engine,partition_id,partition_kind,op_names,data_formats,logical_tensors,fpmath_mode,implementation,backend,exec_time
-onednn_verbose,v1,primitive,create:dispatch,gemm,gpu,gemm,jit:xe_hp:gemm:any,undef,src_a:f16::blocked:ab::f0 src_b:f16::blocked:ab::f0 dst:f16::blocked:ab::f0,,,1x1024:1024x1024,unsupported format tag,src/gpu/intel/jit/gemm/xe_hp_systolic_gemm.cpp:69
-onednn_verbose,v1,primitive,create:dispatch,gemm,gpu,gemm,jit:xe_hp:gemm:any,undef,src_a:f16::blocked:ab::f0 src_b:f16::blocked:ab::f0 dst:f16::blocked:ab::f0,,,1x1024:1024x1024,unsupported format tag,src/gpu/intel/jit/gemm/xe_hp_systolic_gemm.cpp:69
+onednn_verbose,v1,primitive,create:dispatch,gemm,gpu,gemm,jit:xe_hp:gemm:any,undef,src_a:f16::blocked:ab::f0 src_b:f16::blocked:ab::f0 dst:f16::blocked:ab::f0,,,1x1024:1024x1024,unsupported format tag,src/gpu/intel/jit/gemm/xe_hp_systolic_gemm.cpp:80
+onednn_verbose,v1,primitive,create:dispatch,gemm,gpu,gemm,jit:xe_hp:gemm:any,undef,src_a:f16::blocked:ab::f0 src_b:f16::blocked:ab::f0 dst:f16::blocked:ab::f0,,,1x1024:1024x1024,unsupported format tag,src/gpu/intel/jit/gemm/xe_hp_systolic_gemm.cpp:80
onednn_verbose,v1,info,gpu,gemm,consider:64x40,4x8x1,score:xxxxxx
onednn_verbose,v1,info,gpu,gemm,consider:64x32,4x8x1,score:xxxxxx
onednn_verbose,v1,info,gpu,gemm,consider:32x48,8x4x1,score:xxxxxx
@@ -33,7 +34,7 @@ onednn_verbose,v1,info,gpu,gemm,consider:64x32,4x8x1,score:xxxxxx
onednn_verbose,v1,info,gpu,gemm,consider:64x2,2x8x1,score:xxxxxx
onednn_verbose,v1,info,gpu,gemm,consider:64x4,8x2x1,score:xxxxxx
onednn_verbose,v1,info,gpu,gemm,consider:64x4,8x2x1,score:xxxxxx
-onednn_verbose,v1,primitive,create:dispatch,gemm,gpu,gemm,ocl:gemm_with_po:any,undef,src_a:f16::blocked:ab::f0 src_b:f16::blocked:ab::f0 dst:f16::blocked:ab::f0,,,1x1024:1024x1024,failed to create nested primitive jit:gemm:any,src/gpu/intel/ocl/gemm/gemm_with_post_ops.cpp:76
+onednn_verbose,v1,primitive,create:dispatch,gemm,gpu,gemm,ocl:gemm_with_po:any,undef,src_a:f16::blocked:ab::f0 src_b:f16::blocked:ab::f0 dst:f16::blocked:ab::f0,,,1x1024:1024x1024,failed to create nested jit:gemm:any primitive,src/gpu/intel/ocl/gemm/gemm_with_post_ops.cpp:76
onednn_verbose,v1,info,gpu,gemm,consider:64x40,4x8x1,score:xxxxxx
onednn_verbose,v1,info,gpu,gemm,consider:64x32,4x8x1,score:xxxxxx
onednn_verbose,v1,info,gpu,gemm,consider:32x48,8x4x1,score:xxxxxx
@@ -67,6 +68,7 @@ onednn_verbose,v1,primitive,exec,gpu,matmul,jit:gemm:any,undef,src:f16::blocked:
onednn_verbose,v1,primitive,create:cache_miss,gpu,convolution,jit:ir,forward_training,src:f64::blocked:abcd::f0 wei:f64::blocked:abcd::f0 bia:undef::undef::: dst:f64::blocked:abcd::f0,attr-scratchpad:user,alg:convolution_direct,mb1_ic64oc64_ih256oh256kh3sh1dh0ph1_iw256ow256kw3sw1dw0pw1,xxxxxx
onednn_verbose,v1,primitive,exec,gpu,convolution,jit:ir,forward_training,src:f64::blocked:abcd::f0 wei:f64::blocked:abcd::f0 bia:undef::undef::: dst:f64::blocked:abcd::f0,attr-scratchpad:user,alg:convolution_direct,mb1_ic64oc64_ih256oh256kh3sh1dh0ph1_iw256ow256kw3sw1dw0pw1,xxxxxx
onednn_verbose,v1,primitive,create:cache_miss,gpu,convolution,jit:ir,backward_data,src:f64::blocked:abcd::f0 wei:f64::blocked:abcd::f0 bia:undef::undef::: dst:f64::blocked:abcd::f0,attr-scratchpad:user,alg:convolution_direct,mb1_ic64oc64_ih256oh256kh3sh1dh0ph1_iw256ow256kw3sw1dw0pw1,xxxxxx
-onednn_verbose,v1,primitive,exec,gpu,convolution,jit:ir,backward_data,src:f64::blocked:abcd::f0 wei:f64::blocked:abcd::f0 bia:undef::undef::: dst:f64::blocked:abcd::f0,attr-scratchpad:user,alg:convolution_direct,mb1_ic64oc64_ih256oh256kh3sh1dh0ph1_iw256ow256kw3sw1dw0pw1,xxxxxx
-onednn_verbose,v1,primitive,create:cache_miss,gpu,convolution,jit:ir,backward_weights,src:f64::blocked:abcd::f0 wei:f64::blocked:abcd::f0 bia:undef::undef::: dst:f64::blocked:abcd::f0,attr-scratchpad:user,alg:convolution_direct,mb1_ic64oc64_ih256oh256kh3sh1dh0ph1_iw256ow256kw3sw1dw0pw1,xxxxxx
-onednn_verbose,v1,primitive,exec,gpu,convolution,jit:ir,backward_weights,src:f64::blocked:abcd::f0 wei:f64::blocked:abcd::f0 bia:undef::undef::: dst:f64::blocked:abcd::f0,attr-scratchpad:user,alg:convolution_direct,mb1_ic64oc64_ih256oh256kh3sh1dh0ph1_iw256ow256kw3sw1dw0pw1,xxxxxx
+FATAL: Unexpected page fault from GPU at 0xff0000000ea3a000, ctx_id: 1 (CCS) type: 0 (NotPresent), level: 1 (PDE), access: 0 (Read), banned: 1, aborting.
+FATAL: Unexpected page fault from GPU at 0xff0000000ea3a000, ctx_id: 1 (CCS) type: 0 (NotPresent), level: 1 (PDE), access: 0 (Read), banned: 1, aborting.
+Abort was called at 288 line in file:
+./shared/source/os_interface/linux/drm_neo.cpp

Does the issue happen if and only if there's this specific matmul in there? What happens if the matmul op gets removed?
Yes. If I remove the matmul op, the issue no longer occurs.

Does the issue happen with f16 matmul on any other shape? Like 1x256:256x256? Or maybe 16x16:16x16?
Yes. The smallest shape I found that can reproduce the issue is ...

Does it happen with f32 or f64 matmul used instead of f16?
No. If I use f32 or f64 for matmul with the same tensor shapes, conv2d works fine and no error occurs.

Does it happen for this specific convolution shape? Will any other fail?
Yes. The smallest convolution shape I found so far that reproduces the issue is as follows:

x_cpu shape: [1, 32, 256, 256]
grad_cpu shape: [1, 9, 256, 256]
conv_cpu = nn.Conv2d(32, 9, kernel_size=3, stride=1, padding=1, bias=False).double()

Thank you again for your time and support!
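For reference, a minimal end-to-end sketch of how these pieces fit together (assumptions: the preceding f16 matmul uses the 1x1024:1024x1024 shape seen in the verbose log above, and the convolution is driven through one forward and one backward pass as in the traceback later in this thread; this is a sketch requiring a PyTorch build with XPU support, not the reporter's exact onednn-reproducer.py):

import torch
import torch.nn as nn

xpu = torch.device("xpu")

# f16 matmul preceding the convolution (shape taken from the verbose log above)
a = torch.rand([1, 1024], device=xpu, dtype=torch.float16)
b = torch.rand([1024, 1024], device=xpu, dtype=torch.float16)
torch.mm(a, b)

# smallest failing convolution shapes reported above (f64)
x = torch.randn([1, 32, 256, 256], dtype=torch.double, device=xpu, requires_grad=True)
grad = torch.randn([1, 9, 256, 256], dtype=torch.double, device=xpu)
conv = nn.Conv2d(32, 9, kernel_size=3, stride=1, padding=1, bias=False).double().to(xpu)

y = conv(x)       # forward convolution
y.backward(grad)  # triggers backward_data / backward_weights
torch.xpu.synchronize()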
I can't reproduce it on an MTL iGPU with oneDNN v3.7.0, but I reproduced it on Intel(R) Data Center GPU Max 1100 (driver_version: 1.6.33022) with oneDNN v3.7.1 (commit 8d263e6).
@hoshibara Thank you for the data and the information. @echeresh @kealan-barbieri This sounds like a real and not straightforward problem with conv bwd_d; could you please assist further?
@hoshibara Thanks for the details. To rule out a buffer size mismatch, can you please collect a trace with the onetrace tool? After that, you can run your application as:

It should trace all Level Zero API calls, and we can check all memory allocations to see if they have the correct sizes.
Hi @echeresh

Command:

$ ZE_AFFINITY_MASK=2 onetrace -c python onednn-reproducer.py > onetrace.log 2>&1
[1] 1226761 IOT instruction (core dumped)  ZE_AFFINITY_MASK=2 onetrace -c python onednn-reproducer.py > onetrace.log 2>&1

Log:

Also, I noticed a strange thing that may help in locating the problem.

$ ZE_AFFINITY_MASK=3 python onednn-reproducer.py (can finish properly)
$ ZE_AFFINITY_MASK=2 python onednn-reproducer.py (Unexpected page fault)
Traceback (most recent call last):
File "/workspace1/xingyuan/20250313-onednn-pytorch-reproducer/onednn-reproducer.py", line 33, in <module>
y_dpcpp.backward(grad_dpcpp)
File "/workspace1/xingyuan/20250313-onednn-pytorch-reproducer/hoshibara-pytorch/torch/_tensor.py", line 648, in backward
torch.autograd.backward(
File "/workspace1/xingyuan/20250313-onednn-pytorch-reproducer/hoshibara-pytorch/torch/autograd/__init__.py", line 353, in backward
_engine_run_backward(
File "/workspace1/xingyuan/20250313-onednn-pytorch-reproducer/hoshibara-pytorch/torch/autograd/graph.py", line 824, in _engine_run_backward
return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
RuntimeError: UR backend failed. UR backend returns:40 (UR_RESULT_ERROR_OUT_OF_RESOURCES)
FATAL: Unexpected page fault from GPU at 0xff0000004bdf7000, ctx_id: 1 (CCS) type: 0 (NotPresent), level: 2 (PDP), access: 0 (Read), banned: 1, aborting.
FATAL: Unexpected page fault from GPU at 0xff0000004bdf7000, ctx_id: 1 (CCS) type: 0 (NotPresent), level: 2 (PDP), access: 0 (Read), banned: 1, aborting.
Abort was called at 288 line in file:
./shared/source/os_interface/linux/drm_neo.cpp
[1] 792005 IOT instruction (core dumped)  ZE_AFFINITY_MASK=2 python onednn-reproducer.py

$ sudo clinfo | grep GPU
Device Name Intel(R) Data Center GPU Max 1550
Device Type GPU
Device Name Intel(R) Data Center GPU Max 1550
Device Type GPU
Device Name Intel(R) Data Center GPU Max 1550
Device Type GPU
Device Name Intel(R) Data Center GPU Max 1550
Device Type GPU
Device Name Intel(R) Data Center GPU Max 1550
Device Type GPU
Device Name Intel(R) Data Center GPU Max 1550
Device Type GPU
Device Name Intel(R) Data Center GPU Max 1550
Device Type GPU
Device Name Intel(R) Data Center GPU Max 1550
Device Type GPU
@hoshibara Thanks for the log. I don't see any obvious buffer size mismatches. For reference, we have the following kernel sequence (also listing their arguments) at the end, before the page fault:
Based on the previous analysis, I assume the issue is in the first three kernels (they are related to backward-by-data convolution). What is interesting is that the faulting address ...

@hoshibara As for the next steps, I suggest the following:
As for further debugging, we usually try some simple tricks:
Hi, @echeresh
ONEDNN_VERBOSE=2 ZE_AFFINITY_MASK=0 onetrace -c python onednn-reproducer.py > onetrace.verbose-2.err.log 2>&1

The log is as follows:

In this log, I observed the page fault occurring during the ...
@hoshibara Thanks, the faulting address looks similar, ~35 MB after the scratchpad bounds.
Okay, I guess PyTorch doesn't support OpenCL backend...
Yes - disable ...

So far, I have only one suggestion: can you please collect and share oneDNN kernel dumps (ONEDNN_JIT_DUMP=1)?

If this doesn't clarify anything, as the next step I would try to reproduce and debug it on our side - for that, can you please share the links/steps you used to install and run PyTorch?
@echeresh Here are my dump files:

ONEDNN_VERBOSE=2 ZE_AFFINITY_MASK=0 ONEDNN_JIT_DUMP=1 python -u onednn-reproducer.py

And my steps to install PyTorch are as follows:

CONDA_ENV_NAME=onednn-reproducer
conda create -n $CONDA_ENV_NAME python=3.10 -y
conda deactivate||source deactivate||true
conda activate $CONDA_ENV_NAME||source activate $CONDA_ENV_NAME||true
ONEAPI_HOME=/home/sdp/intel/oneapi
source ${ONEAPI_HOME}/compiler/latest/env/vars.sh
source ${ONEAPI_HOME}/tbb/latest/env/vars.sh
source ${ONEAPI_HOME}/umf/latest/env/vars.sh
source ${ONEAPI_HOME}/pti/latest/env/vars.sh
source ${ONEAPI_HOME}/ccl/latest/env/vars.sh
export USE_XPU=1
export PYTORCH_ENABLE_XPU_FALLBACK=0
export PYTORCH_DEBUG_XPU_FALLBACK=1
export USE_KINETO=1
export BUILD_SEPARATE_OPS=ON
export USE_AOT_DEVLIST='pvc'
export TORCH_XPU_ARCH_LIST='pvc'
export _GLIBCXX_USE_CXX11_ABI=1
export CMAKE_PREFIX_PATH=${CONDA_PREFIX:-"$(dirname $(which conda))/../"}
git clone https://github.com/pytorch/pytorch.git pytorch
cd pytorch
git checkout 037d7af778dc26878307825deb559e9c767ed2ba
cd ..
# you can modify onednn version here:
# pytorch/cmake/Modules/FindMKLDNN.cmake:52
cd pytorch
pip uninstall torch -y
pip uninstall torch -y
pip uninstall torch -y
git submodule sync
git submodule update --init --recursive
pip install -r requirements.txt
make triton
python setup.py develop
cd ..
@echeresh Here is another easy way to set up a PyTorch env and reproduce the issue.
Can't reproduce the issue by running matmul and then running conv ops, for example:
You'd need at least two convolutions: one forward and one backward data (and maybe one backward weights as well), according to the messages above.
@hoshibara @rcao8 Thanks. I was able to reproduce it and reduce it to two primitives: MatMul (called once) and deconvolution (in a loop). The loop part has only one kernel (deconvolution) and doesn't have any kernel creations or memory allocations beyond the first iteration. So we basically execute the same kernel with the same buffers over and over again. The fact that it occasionally generates a page fault after some time points to a hardware bug. I'll work on a C++ reproducer next week to try different things to understand what triggers this behavior.

import torch
import torch.nn as nn
cpu_device = torch.device("cpu")
dpcpp_device = torch.device("xpu")
O, I, H, W = 128, 128, 128, 128
def mm():
    m, n = 1, 100
    a = torch.rand([m, n], device=dpcpp_device, dtype=torch.float16)
    b = torch.rand([n, n], device=dpcpp_device, dtype=torch.float16)
    torch.mm(a, b)
conv_type=torch.double
x_cpu = torch.randn([1, O, H, W], dtype=conv_type, device=cpu_device)
w_cpu = torch.randn([O, I, 1, 1], dtype=conv_type, device=cpu_device)
x_cpu = x_cpu.to(memory_format=torch.channels_last)
deconv_cpu = nn.ConvTranspose2d(O, I, kernel_size=1, stride=1, padding=0, bias=False).double()
x_dpcpp = x_cpu.to(dpcpp_device)
deconv_dpcpp = deconv_cpu.to(dpcpp_device)
mm()
for i in range(100):
    y_dpcpp = deconv_dpcpp(x_dpcpp)

Page fault:

$ ONEDNN_VERBOSE=1 python 1.py
onednn_verbose,v1,info,oneDNN v3.7.1 (commit 8d263e693366ef8db40acc569cc7d8edf644556d)
onednn_verbose,v1,info,cpu,runtime:threadpool,nthr:52
onednn_verbose,v1,info,cpu,isa:Intel AVX-512 with float16, Intel DL Boost and bfloat16 support
onednn_verbose,v1,info,gpu,runtime:DPC++
onednn_verbose,v1,info,gpu,engine,sycl gpu device count:1
onednn_verbose,v1,info,gpu,engine,0,backend:Level Zero,name:Intel(R) Data Center GPU Max 1550,driver_version:1.3.30049,binary_kernels:enabled
onednn_verbose,v1,info,graph,backend,0:dnnl_backend
onednn_verbose,v1,primitive,info,template:operation,engine,primitive,implementation,prop_kind,memory_descriptors,attributes,auxiliary,problem_desc,exec_time
onednn_verbose,v1,graph,info,template:operation,engine,partition_id,partition_kind,op_names,data_formats,logical_tensors,fpmath_mode,implementation,backend,exec_time
onednn_verbose,v1,primitive,exec,gpu,matmul,jit:gemm:any,undef,src:f16::blocked:ab::f0 wei:f16::blocked:ab::f0 dst:f16::blocked:ab::f0,attr-scratchpad:user,,1x100:100x100,0.215088
onednn_verbose,v1,primitive,exec,gpu,deconvolution,conv:any+jit:ir,forward_training,src:f64::blocked:acdb::f0 wei:f64::blocked:bacd::f0 bia:undef::undef::: dst:f64::blocked:acdb::f0,attr-scratchpad:user,alg:deconvolution_direct,mb1_ic128oc128_ih128oh128kh1sh1dh0ph0_iw128ow128kw1sw1dw0pw0,0.203125
onednn_verbose,v1,primitive,exec,gpu,deconvolution,conv:any+jit:ir,forward_training,src:f64::blocked:acdb::f0 wei:f64::blocked:bacd::f0 bia:undef::undef::: dst:f64::blocked:acdb::f0,attr-scratchpad:user,alg:deconvolution_direct,mb1_ic128oc128_ih128oh128kh1sh1dh0ph0_iw128ow128kw1sw1dw0pw0,0.0830078
onednn_verbose,v1,primitive,exec,gpu,deconvolution,conv:any+jit:ir,forward_training,src:f64::blocked:acdb::f0 wei:f64::blocked:bacd::f0 bia:undef::undef::: dst:f64::blocked:acdb::f0,attr-scratchpad:user,alg:deconvolution_direct,mb1_ic128oc128_ih128oh128kh1sh1dh0ph0_iw128ow128kw1sw1dw0pw0,0.0688477
onednn_verbose,v1,primitive,exec,gpu,deconvolution,conv:any+jit:ir,forward_training,src:f64::blocked:acdb::f0 wei:f64::blocked:bacd::f0 bia:undef::undef::: dst:f64::blocked:acdb::f0,attr-scratchpad:user,alg:deconvolution_direct,mb1_ic128oc128_ih128oh128kh1sh1dh0ph0_iw128ow128kw1sw1dw0pw0,0.0671387
onednn_verbose,v1,primitive,exec,gpu,deconvolution,conv:any+jit:ir,forward_training,src:f64::blocked:acdb::f0 wei:f64::blocked:bacd::f0 bia:undef::undef::: dst:f64::blocked:acdb::f0,attr-scratchpad:user,alg:deconvolution_direct,mb1_ic128oc128_ih128oh128kh1sh1dh0ph0_iw128ow128kw1sw1dw0pw0,0.0849609
onednn_verbose,v1,primitive,exec,gpu,deconvolution,conv:any+jit:ir,forward_training,src:f64::blocked:acdb::f0 wei:f64::blocked:bacd::f0 bia:undef::undef::: dst:f64::blocked:acdb::f0,attr-scratchpad:user,alg:deconvolution_direct,mb1_ic128oc128_ih128oh128kh1sh1dh0ph0_iw128ow128kw1sw1dw0pw0,0.103027
onednn_verbose,v1,primitive,exec,gpu,deconvolution,conv:any+jit:ir,forward_training,src:f64::blocked:acdb::f0 wei:f64::blocked:bacd::f0 bia:undef::undef::: dst:f64::blocked:acdb::f0,attr-scratchpad:user,alg:deconvolution_direct,mb1_ic128oc128_ih128oh128kh1sh1dh0ph0_iw128ow128kw1sw1dw0pw0,0.064209
onednn_verbose,v1,primitive,exec,gpu,deconvolution,conv:any+jit:ir,forward_training,src:f64::blocked:acdb::f0 wei:f64::blocked:bacd::f0 bia:undef::undef::: dst:f64::blocked:acdb::f0,attr-scratchpad:user,alg:deconvolution_direct,mb1_ic128oc128_ih128oh128kh1sh1dh0ph0_iw128ow128kw1sw1dw0pw0,0.0620117
onednn_verbose,v1,primitive,exec,gpu,deconvolution,conv:any+jit:ir,forward_training,src:f64::blocked:acdb::f0 wei:f64::blocked:bacd::f0 bia:undef::undef::: dst:f64::blocked:acdb::f0,attr-scratchpad:user,alg:deconvolution_direct,mb1_ic128oc128_ih128oh128kh1sh1dh0ph0_iw128ow128kw1sw1dw0pw0,0.0629883
onednn_verbose,v1,primitive,exec,gpu,deconvolution,conv:any+jit:ir,forward_training,src:f64::blocked:acdb::f0 wei:f64::blocked:bacd::f0 bia:undef::undef::: dst:f64::blocked:acdb::f0,attr-scratchpad:user,alg:deconvolution_direct,mb1_ic128oc128_ih128oh128kh1sh1dh0ph0_iw128ow128kw1sw1dw0pw0,0.996826
FATAL: Unexpected page fault from GPU at 0xff00fffffff9e000, ctx_id: 1 (CCS) type: 0 (NotPresent), level: 1 (PDE), access: 0 (Read), banned: 1, aborting.
FATAL: Unexpected page fault from GPU at 0xff00fffffff9e000, ctx_id: 1 (CCS) type: 0 (NotPresent), level: 1 (PDE), access: 0 (Read), banned: 1, aborting.
Abort was called at 287 line in file:
./shared/source/os_interface/linux/drm_neo.cpp
Aborted (core dumped)
@hoshibara @rcao8 The preliminary conclusion is that this is caused by a known hardware bug in PVC related to the double data type. I was able to reproduce it in benchdnn, so it's completely on our side from this point - we'll investigate it and prepare a fix, which should be available in a future oneDNN release. As a mitigation, I see that overriding the thread arbitration policy makes the issue non-reproducible, but it may come with a performance cost and may still produce the same issue, just at a rarer rate:

$ OverrideThreadArbitrationPolicy=2 NEOReadDebugKeys=1 PrintDebugSettings=1 python 1.py
Non-default value of debug variable: PrintDebugSettings = 1
Non-default value of debug variable: OverrideThreadArbitrationPolicy = 2
...
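If the mitigation needs to be applied from inside a Python launcher rather than on the command line, here is a hedged sketch (assumption: the NEO debug keys are read when the Level Zero driver initializes, so they must be set before the first XPU call, e.g. before torch is imported):

import os

# Same mitigation as the command line above, set programmatically.
os.environ["NEOReadDebugKeys"] = "1"
os.environ["PrintDebugSettings"] = "1"
os.environ["OverrideThreadArbitrationPolicy"] = "2"

import torch  # imported after the debug keys are set

# ... run the workload here (hypothetical placeholder) ...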
@hoshibara @rcao8 We were able to confirm that the PVC HW bug @echeresh mentioned is the culprit. The linked PR (#2935) should fix the issue you're seeing.
Hi, I recently encountered an issue related to conv2d backward.
After updating oneDNN to 3.7.0, conv2d backward throws an Unexpected page fault after a matmul operation with a specific shape.
I tried to reproduce this issue in C++ following the oneDNN verbose logs from a runnable tensor shape, but couldn't replicate it.
I guess it's caused by operator inconsistencies generated under specific shapes.
Could I get your suggestions on how to further investigate or debug this issue?
Pytorch reproducer
oneDNN Verbose