
How to debug conv2d backward Unexpected page fault error #2864

Closed
hoshibara opened this issue Mar 12, 2025 · 17 comments · Fixed by #2935
Labels
bug A confirmed library bug platform:gpu-intel Codeowner: @oneapi-src/onednn-gpu-intel

Comments

@hoshibara

hoshibara commented Mar 12, 2025

Hi, I recently encountered an issue related to conv2d backward.
After updating oneDNN to 3.7.0, conv2d backward throws an Unexpected page fault after a matmul operation with a specific shape.
I tried to reproduce the issue in C++ following the oneDNN verbose logs from a runnable tensor shape, but couldn't replicate it.
My guess is that it's caused by operator inconsistencies generated under specific shapes.
Could I get your suggestions on how to further investigate or debug this issue?


PyTorch reproducer

import os
import torch
import torch.nn as nn
from torch.testing._internal.common_utils import TestCase

cpu_device = torch.device("cpu")
dpcpp_device = torch.device("xpu")

a_size = int(os.environ.get('A_SIZE', 1))

# This reproducer will throw page fault when a_size=1~8

a=torch.rand([a_size, 1024], device='xpu:0', dtype=torch.float16)
b=torch.rand([1024, 1024], device='xpu:0', dtype=torch.float16)
torch.mm(a,b)

dtype=torch.double

x_cpu = torch.randn(
    [1, 64, 256, 256], dtype=dtype, device=cpu_device, requires_grad=True
)
grad_cpu = torch.full(
    [1, 64, 256, 256], 1e-3, dtype=dtype, device=cpu_device, requires_grad=True
)
conv_cpu = nn.Conv2d(
    64, 64, kernel_size=3, stride=1, padding=1, bias=False
).double()

x_dpcpp = x_cpu.to(dpcpp_device).requires_grad_()
grad_dpcpp = grad_cpu.to(dpcpp_device)
conv_dpcpp = conv_cpu.to(dpcpp_device)
y_dpcpp = conv_dpcpp(x_dpcpp)
y_dpcpp.backward(grad_dpcpp)

oneDNN Verbose

onednn_verbose,v1,info,oneDNN v3.7.0 (commit 5e9224036021433d2577548ed0539fe9a53256bc)
onednn_verbose,v1,info,cpu,runtime:threadpool,nthr:56
onednn_verbose,v1,info,cpu,isa:Intel AVX-512 with float16, Intel DL Boost and bfloat16 support 
onednn_verbose,v1,info,gpu,runtime:DPC++
onednn_verbose,v1,info,gpu,engine,sycl gpu device count:8 
onednn_verbose,v1,info,gpu,engine,0,backend:Level Zero,name:Intel(R) Data Center GPU Max 1550,driver_version:1.6.31294,binary_kernels:enabled
onednn_verbose,v1,info,gpu,engine,1,backend:Level Zero,name:Intel(R) Data Center GPU Max 1550,driver_version:1.6.31294,binary_kernels:enabled
onednn_verbose,v1,info,gpu,engine,2,backend:Level Zero,name:Intel(R) Data Center GPU Max 1550,driver_version:1.6.31294,binary_kernels:enabled
onednn_verbose,v1,info,gpu,engine,3,backend:Level Zero,name:Intel(R) Data Center GPU Max 1550,driver_version:1.6.31294,binary_kernels:enabled
onednn_verbose,v1,info,gpu,engine,4,backend:Level Zero,name:Intel(R) Data Center GPU Max 1550,driver_version:1.6.31294,binary_kernels:enabled
onednn_verbose,v1,info,gpu,engine,5,backend:Level Zero,name:Intel(R) Data Center GPU Max 1550,driver_version:1.6.31294,binary_kernels:enabled
onednn_verbose,v1,info,gpu,engine,6,backend:Level Zero,name:Intel(R) Data Center GPU Max 1550,driver_version:1.6.31294,binary_kernels:enabled
onednn_verbose,v1,info,gpu,engine,7,backend:Level Zero,name:Intel(R) Data Center GPU Max 1550,driver_version:1.6.31294,binary_kernels:enabled
onednn_verbose,v1,info,graph,backend,0:dnnl_backend
onednn_verbose,v1,primitive,info,template:operation,engine,primitive,implementation,prop_kind,memory_descriptors,attributes,auxiliary,problem_desc
onednn_verbose,v1,graph,info,template:operation,engine,partition_id,partition_kind,op_names,data_formats,logical_tensors,fpmath_mode,implementation,backend
onednn_verbose,v1,primitive,exec,gpu:0,matmul,jit:gemm:any,undef,src:f16::blocked:ab::f0 wei:f16::blocked:ab::f0 dst:f16::blocked:ab::f0,attr-scratchpad:user,,1x1024:1024x1024
onednn_verbose,v1,primitive,exec,gpu:0,convolution,jit:ir,forward_training,src:f64::blocked:abcd::f0 wei:f64::blocked:abcd::f0 bia:undef::undef::: dst:f64::blocked:abcd::f0,attr-scratchpad:user,alg:convolution_direct,mb1_ic64oc64_ih256oh256kh3sh1dh0ph1_iw256ow256kw3sw1dw0pw1
FATAL: Unexpected page fault from GPU at 0xff0000000e97a000, ctx_id: 1 (CCS) type: 0 (NotPresent), level: 1 (PDE), access: 0 (Read), banned: 1, aborting.
FATAL: Unexpected page fault from GPU at 0xff0000000e97a000, ctx_id: 1 (CCS) type: 0 (NotPresent), level: 1 (PDE), access: 0 (Read), banned: 1, aborting.
Abort was called at 288 line in file:
./shared/source/os_interface/linux/drm_neo.cpp
Aborted
@dzarukin
Contributor

Hi @hoshibara, thank you for providing verbose logs; that gives some perspective and tips, but if you could provide ONEDNN_VERBOSE=all output (as opposed to 1), that may give some more context for the team triaging the issue.
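
With the Python reproducer, the variable can be set either in the shell (ONEDNN_VERBOSE=all python reproducer.py) or at the top of the script; a minimal sketch, assuming it is set before the first XPU operation runs:

# Sketch: enable full oneDNN verbosity from inside the script, before any primitive is created.
import os
os.environ["ONEDNN_VERBOSE"] = "all"

import torch  # the rest of the reproducer stays as posted above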

If the problem is in the execution of the backward path, the log should report a create line for that backward convolution, which would help to see whether the parameters matched and which implementation was used.

It was also mentioned that this started after the upgrade to v3.7. Does that imply the previous version was working well? And if so, which version was that?

And in case you are eager to help us with debugging, here are some more questions:

  • Does the issue happen if and only if there's this specific matmul in there? What happens if matmul op gets removed?
  • Does the issue happen with f16 matmul on any other shape? Like 1x256:256x256? Or maybe 16x16:16x16?
  • Does it happen with f32 or f64 matmul used instead of f16?
  • Does it happen for this specific convolution shape? Will any other fail?

Answering these questions will help to narrow down the scope and potentially simplify the standalone oneDNN reproducer.
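
For instance, a parameterized variant of the matmul part makes these experiments quick to run; a sketch (the MM_DTYPE/MM_M/MM_K/MM_N environment variables are hypothetical, not part of the original reproducer):

# Sketch: vary the matmul dtype and shape from the command line, e.g.
#   MM_DTYPE=float32 MM_M=16 MM_K=16 MM_N=16 python reproducer.py
import os
import torch

mm_dtype = getattr(torch, os.environ.get("MM_DTYPE", "float16"))
m = int(os.environ.get("MM_M", 1))
k = int(os.environ.get("MM_K", 1024))
n = int(os.environ.get("MM_N", 1024))

a = torch.rand([m, k], device="xpu:0", dtype=mm_dtype)
b = torch.rand([k, n], device="xpu:0", dtype=mm_dtype)
torch.mm(a, b)
# ... followed by the unchanged conv2d forward/backward part of the reproducer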

Thank you.

@hoshibara
Author

hoshibara commented Mar 13, 2025

Hi @dzarukin, thank you for your patience and suggestions.
I followed your advice and did further testing. Here are the results:

It was also mentioned it started after upgrade to v3.7. Does it imply the previous version was working well? And if so, which version was that?

I rolled back the oneDNN version integrated in the latest PyTorch and confirmed that this issue first appeared in v3.7 (5e92240).
When I switched to the v3.6.2 (2eb3dd1) branch, the test ran properly.
In addition, I collected ONEDNN_VERBOSE=all logs for both versions while running the reproducer. The differences are as follows:

diff --git a/shape_1.3_6_2.processed.log b/shape_1.3_7.processed.log
index d3bd144..25dad21 100644
--- a/shape_1.3_6_2.processed.log
+++ b/shape_1.3_7.processed.log
@@ -1,13 +1,14 @@
-onednn_verbose,v1,info,oneDNN v3.6.2 (commit 2eb3dd1082db767fab171e934c551c609008289a)
+onednn_verbose,v1,info,oneDNN v3.7.0 (commit 5e9224036021433d2577548ed0539fe9a53256bc)
 onednn_verbose,v1,info,cpu,runtime:threadpool,nthr:56
 onednn_verbose,v1,info,cpu,isa:Intel AVX-512 with float16, Intel DL Boost and bfloat16 support 
 onednn_verbose,v1,info,gpu,runtime:DPC++
+onednn_verbose,v1,info,gpu,engine,sycl gpu device count:1 
 onednn_verbose,v1,info,gpu,engine,0,backend:Level Zero,name:Intel(R) Data Center GPU Max 1550,driver_version:1.6.31294,binary_kernels:enabled
 onednn_verbose,v1,info,graph,backend,0:dnnl_backend
 onednn_verbose,v1,primitive,info,template:operation,engine,primitive,implementation,prop_kind,memory_descriptors,attributes,auxiliary,problem_desc,exec_time
 onednn_verbose,v1,graph,info,template:operation,engine,partition_id,partition_kind,op_names,data_formats,logical_tensors,fpmath_mode,implementation,backend,exec_time
-onednn_verbose,v1,primitive,create:dispatch,gemm,gpu,gemm,jit:xe_hp:gemm:any,undef,src_a:f16::blocked:ab::f0 src_b:f16::blocked:ab::f0 dst:f16::blocked:ab::f0,,,1x1024:1024x1024,unsupported format tag,src/gpu/intel/jit/gemm/xe_hp_systolic_gemm.cpp:69
-onednn_verbose,v1,primitive,create:dispatch,gemm,gpu,gemm,jit:xe_hp:gemm:any,undef,src_a:f16::blocked:ab::f0 src_b:f16::blocked:ab::f0 dst:f16::blocked:ab::f0,,,1x1024:1024x1024,unsupported format tag,src/gpu/intel/jit/gemm/xe_hp_systolic_gemm.cpp:69
+onednn_verbose,v1,primitive,create:dispatch,gemm,gpu,gemm,jit:xe_hp:gemm:any,undef,src_a:f16::blocked:ab::f0 src_b:f16::blocked:ab::f0 dst:f16::blocked:ab::f0,,,1x1024:1024x1024,unsupported format tag,src/gpu/intel/jit/gemm/xe_hp_systolic_gemm.cpp:80
+onednn_verbose,v1,primitive,create:dispatch,gemm,gpu,gemm,jit:xe_hp:gemm:any,undef,src_a:f16::blocked:ab::f0 src_b:f16::blocked:ab::f0 dst:f16::blocked:ab::f0,,,1x1024:1024x1024,unsupported format tag,src/gpu/intel/jit/gemm/xe_hp_systolic_gemm.cpp:80
 onednn_verbose,v1,info,gpu,gemm,consider:64x40,4x8x1,score:xxxxxx
 onednn_verbose,v1,info,gpu,gemm,consider:64x32,4x8x1,score:xxxxxx
 onednn_verbose,v1,info,gpu,gemm,consider:32x48,8x4x1,score:xxxxxx
@@ -33,7 +34,7 @@ onednn_verbose,v1,info,gpu,gemm,consider:64x32,4x8x1,score:xxxxxx
 onednn_verbose,v1,info,gpu,gemm,consider:64x2,2x8x1,score:xxxxxx
 onednn_verbose,v1,info,gpu,gemm,consider:64x4,8x2x1,score:xxxxxx
 onednn_verbose,v1,info,gpu,gemm,consider:64x4,8x2x1,score:xxxxxx
-onednn_verbose,v1,primitive,create:dispatch,gemm,gpu,gemm,ocl:gemm_with_po:any,undef,src_a:f16::blocked:ab::f0 src_b:f16::blocked:ab::f0 dst:f16::blocked:ab::f0,,,1x1024:1024x1024,failed to create nested primitive jit:gemm:any,src/gpu/intel/ocl/gemm/gemm_with_post_ops.cpp:76
+onednn_verbose,v1,primitive,create:dispatch,gemm,gpu,gemm,ocl:gemm_with_po:any,undef,src_a:f16::blocked:ab::f0 src_b:f16::blocked:ab::f0 dst:f16::blocked:ab::f0,,,1x1024:1024x1024,failed to create nested jit:gemm:any primitive,src/gpu/intel/ocl/gemm/gemm_with_post_ops.cpp:76
 onednn_verbose,v1,info,gpu,gemm,consider:64x40,4x8x1,score:xxxxxx
 onednn_verbose,v1,info,gpu,gemm,consider:64x32,4x8x1,score:xxxxxx
 onednn_verbose,v1,info,gpu,gemm,consider:32x48,8x4x1,score:xxxxxx
@@ -67,6 +68,7 @@ onednn_verbose,v1,primitive,exec,gpu,matmul,jit:gemm:any,undef,src:f16::blocked:
 onednn_verbose,v1,primitive,create:cache_miss,gpu,convolution,jit:ir,forward_training,src:f64::blocked:abcd::f0 wei:f64::blocked:abcd::f0 bia:undef::undef::: dst:f64::blocked:abcd::f0,attr-scratchpad:user,alg:convolution_direct,mb1_ic64oc64_ih256oh256kh3sh1dh0ph1_iw256ow256kw3sw1dw0pw1,xxxxxx
 onednn_verbose,v1,primitive,exec,gpu,convolution,jit:ir,forward_training,src:f64::blocked:abcd::f0 wei:f64::blocked:abcd::f0 bia:undef::undef::: dst:f64::blocked:abcd::f0,attr-scratchpad:user,alg:convolution_direct,mb1_ic64oc64_ih256oh256kh3sh1dh0ph1_iw256ow256kw3sw1dw0pw1,xxxxxx
 onednn_verbose,v1,primitive,create:cache_miss,gpu,convolution,jit:ir,backward_data,src:f64::blocked:abcd::f0 wei:f64::blocked:abcd::f0 bia:undef::undef::: dst:f64::blocked:abcd::f0,attr-scratchpad:user,alg:convolution_direct,mb1_ic64oc64_ih256oh256kh3sh1dh0ph1_iw256ow256kw3sw1dw0pw1,xxxxxx
-onednn_verbose,v1,primitive,exec,gpu,convolution,jit:ir,backward_data,src:f64::blocked:abcd::f0 wei:f64::blocked:abcd::f0 bia:undef::undef::: dst:f64::blocked:abcd::f0,attr-scratchpad:user,alg:convolution_direct,mb1_ic64oc64_ih256oh256kh3sh1dh0ph1_iw256ow256kw3sw1dw0pw1,xxxxxx
-onednn_verbose,v1,primitive,create:cache_miss,gpu,convolution,jit:ir,backward_weights,src:f64::blocked:abcd::f0 wei:f64::blocked:abcd::f0 bia:undef::undef::: dst:f64::blocked:abcd::f0,attr-scratchpad:user,alg:convolution_direct,mb1_ic64oc64_ih256oh256kh3sh1dh0ph1_iw256ow256kw3sw1dw0pw1,xxxxxx
-onednn_verbose,v1,primitive,exec,gpu,convolution,jit:ir,backward_weights,src:f64::blocked:abcd::f0 wei:f64::blocked:abcd::f0 bia:undef::undef::: dst:f64::blocked:abcd::f0,attr-scratchpad:user,alg:convolution_direct,mb1_ic64oc64_ih256oh256kh3sh1dh0ph1_iw256ow256kw3sw1dw0pw1,xxxxxx
+FATAL: Unexpected page fault from GPU at 0xff0000000ea3a000, ctx_id: 1 (CCS) type: 0 (NotPresent), level: 1 (PDE), access: 0 (Read), banned: 1, aborting.
+FATAL: Unexpected page fault from GPU at 0xff0000000ea3a000, ctx_id: 1 (CCS) type: 0 (NotPresent), level: 1 (PDE), access: 0 (Read), banned: 1, aborting.
+Abort was called at 288 line in file:
+./shared/source/os_interface/linux/drm_neo.cpp

Does the issue happen if and only if there's this specific matmul in there? What happens if matmul op gets removed?

Yes. If I remove the torch.mm operation in the code, conv2d runs without any errors.

Does the issue happen with f16 matmul on any other shape? Like 1x256:256x256? Or maybe 16x16:16x16?

Yes. The smallest shape I found that can reproduce the issue is (1~8)x8 : 8x4.

Does it happen with f32 or f64 matmul used instead of f16?

No. If I use f32 or f64 for matmul with the same tensor shapes, conv2d works fine and no error occurs.

Does it happen for this specific convolution shape? Will any other fail?

Yes. The smallest convolution shape I found so far that reproduces the issue is as follows:

x_cpu shape: [1, 32, 256, 256]
grad_cpu shape: [1, 9, 256, 256]
conv_cpu = nn.Conv2d(32, 9, kernel_size=3, stride=1, padding=1, bias=False).double()
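
Putting the smallest matmul and convolution shapes together, a combined script would look roughly like this (a sketch assembled from the shapes above, keeping the structure of the original reproducer):

# Sketch: smallest failing matmul shape followed by the smallest conv2d backward shape found so far.
import torch
import torch.nn as nn

dpcpp_device = torch.device("xpu")

a = torch.rand([1, 8], device=dpcpp_device, dtype=torch.float16)
b = torch.rand([8, 4], device=dpcpp_device, dtype=torch.float16)
torch.mm(a, b)

dtype = torch.double
x_dpcpp = torch.randn([1, 32, 256, 256], dtype=dtype, device=dpcpp_device, requires_grad=True)
grad_dpcpp = torch.full([1, 9, 256, 256], 1e-3, dtype=dtype, device=dpcpp_device)
conv_dpcpp = nn.Conv2d(32, 9, kernel_size=3, stride=1, padding=1, bias=False).double().to(dpcpp_device)
y_dpcpp = conv_dpcpp(x_dpcpp)
y_dpcpp.backward(grad_dpcpp)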

Thank you again for your time and support!

@rcao8

rcao8 commented Mar 13, 2025

Can't reproduce it on MTL iGPU for oneDNN v3.7.0, but reproduced on Intel(R) Data Center GPU Max 1100,driver_version:1.6.33022 with oneDNN v3.7.1 (commit 8d263e6)
onednn_verbose,v1,info,oneDNN v3.7.1 (commit 8d263e693366ef8db40acc569cc7d8edf644556d)
onednn_verbose,v1,info,cpu,runtime:threadpool,nthr:32
onednn_verbose,v1,info,cpu,isa:Intel AVX-512 with float16, Intel DL Boost and bfloat16 support and Intel AMX with bfloat16 and 8-bit integer support
onednn_verbose,v1,info,gpu,runtime:DPC++
onednn_verbose,v1,info,gpu,engine,sycl gpu device count:2
onednn_verbose,v1,info,gpu,engine,0,backend:Level Zero,name:Intel(R) Data Center GPU Max 1100,driver_version:1.6.33022,binary_kernels:enabled
onednn_verbose,v1,info,gpu,engine,1,backend:Level Zero,name:Intel(R) Data Center GPU Max 1100,driver_version:1.6.33022,binary_kernels:enabled
onednn_verbose,v1,info,graph,backend,0:dnnl_backend
onednn_verbose,v1,primitive,info,template:operation,engine,primitive,implementation,prop_kind,memory_descriptors,attributes,auxiliary,problem_desc,exec_time
onednn_verbose,v1,graph,info,template:operation,engine,partition_id,partition_kind,op_names,data_formats,logical_tensors,fpmath_mode,implementation,backend,exec_time
onednn_verbose,v1,primitive,exec,gpu:0,matmul,jit:gemm:any,undef,src:f16::blocked:ab::f0 wei:f16::blocked:ab::f0 dst:f16::blocked:ab::f0,attr-scratchpad:user,,1x1024:1024x1024,0.73291
onednn_verbose,v1,primitive,exec,gpu:0,convolution,jit:ir,forward_training,src:f64::blocked:abcd::f0 wei:f64::blocked:abcd::f0 bia:undef::undef::: dst:f64::blocked:abcd::f0,attr-scratchpad:user,alg:convolution_direct,mb1_ic64oc64_ih256oh256kh3sh1dh0ph1_iw256ow256kw3sw1dw0pw1,0.8479
Segmentation fault from GPU at 0xff0000000e6f3000, ctx_id: 1 (CCS) type: 0 (NotPresent), level: 1 (PDE), access: 0 (Read), banned: 1, aborting.
Segmentation fault from GPU at 0xff0000000e6f3000, ctx_id: 1 (CCS) type: 0 (NotPresent), level: 1 (PDE), access: 0 (Read), banned: 1, aborting.
Abort was called at 276 line in file:
/opt/src/compute-neo/shared/source/os_interface/linux/drm_neo.cpp
Aborted (core dumped)

@vpirogov vpirogov added bug A confirmed library bug and removed question labels Mar 13, 2025
@dzarukin
Contributor

@hoshibara Thank you for the data and the information.

@echeresh @kealan-barbieri This sounds like a real and not straightforward problem with conv bwd_d; could you please assist further?

@vpirogov vpirogov added the platform:gpu-intel Codeowner: @oneapi-src/onednn-gpu-intel label Mar 13, 2025
@echeresh
Contributor

@hoshibara Thanks for the details. To rule out a buffer size mismatch, can you please collect a trace with the onetrace tool?
You can build it from source: https://github.com/intel/pti-gpu/tree/master/tools/onetrace

After that you can run your application as:

onetrace -c <application>

It should trace all Level Zero API calls, and we can check all memory allocations to see if they have the correct sizes.

@hoshibara
Author

hoshibara commented Mar 18, 2025

Hi @echeresh
Thanks for your help. Here are my results.

Command:

$ ZE_AFFINITY_MASK=2 onetrace -c python onednn-reproducer.py > onetrace.log 2>&1
[1]    1226761 IOT instruction (core dumped)  ZE_AFFINITY_MASK=2  -c python onednn-reproducer.py > onetrace.log 2>&1

Log:
onetrace.log

Also, I noticed something strange that may help locate the problem.
The Unexpected page fault only appears on one tile of a particular GPU.

$ ZE_AFFINITY_MASK=3 python onednn-reproducer.py (Can properly finish)

$ ZE_AFFINITY_MASK=2 python onednn-reproducer.py (Unexpected page fault)
Traceback (most recent call last):
  File "/workspace1/xingyuan/20250313-onednn-pytorch-reproducer/onednn-reproducer.py", line 33, in <module>
    y_dpcpp.backward(grad_dpcpp)
  File "/workspace1/xingyuan/20250313-onednn-pytorch-reproducer/hoshibara-pytorch/torch/_tensor.py", line 648, in backward
    torch.autograd.backward(
  File "/workspace1/xingyuan/20250313-onednn-pytorch-reproducer/hoshibara-pytorch/torch/autograd/__init__.py", line 353, in backward
    _engine_run_backward(
  File "/workspace1/xingyuan/20250313-onednn-pytorch-reproducer/hoshibara-pytorch/torch/autograd/graph.py", line 824, in _engine_run_backward
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: UR backend failed. UR backend returns:40 (UR_RESULT_ERROR_OUT_OF_RESOURCES)
FATAL: Unexpected page fault from GPU at 0xff0000004bdf7000, ctx_id: 1 (CCS) type: 0 (NotPresent), level: 2 (PDP), access: 0 (Read), banned: 1, aborting.
FATAL: Unexpected page fault from GPU at 0xff0000004bdf7000, ctx_id: 1 (CCS) type: 0 (NotPresent), level: 2 (PDP), access: 0 (Read), banned: 1, aborting.
Abort was called at 288 line in file:
./shared/source/os_interface/linux/drm_neo.cpp
[1]    792005 IOT instruction (core dumped)  ZE_AFFINITY_MASK=2 python onednn-reproducer.py
$ sudo clinfo | grep GPU
  Device Name                                     Intel(R) Data Center GPU Max 1550
  Device Type                                     GPU
  Device Name                                     Intel(R) Data Center GPU Max 1550
  Device Type                                     GPU
  Device Name                                     Intel(R) Data Center GPU Max 1550
  Device Type                                     GPU
  Device Name                                     Intel(R) Data Center GPU Max 1550
  Device Type                                     GPU
  Device Name                                     Intel(R) Data Center GPU Max 1550
  Device Type                                     GPU
  Device Name                                     Intel(R) Data Center GPU Max 1550
  Device Type                                     GPU
  Device Name                                     Intel(R) Data Center GPU Max 1550
  Device Type                                     GPU
  Device Name                                     Intel(R) Data Center GPU Max 1550
  Device Type                                     GPU

@echeresh
Contributor

echeresh commented Mar 18, 2025

@hoshibara Thanks for the log. I don't see any obvious buffer size mismatches. For reference, we have the following kernel sequence (also listing their arguments) at the end, before the page fault:

  • conv_reorder for weights OIHW -> HWIO
    • 0xff00ffffffc00800, 294912 bytes (input)
    • 0xff0000000a800000, 294912 bytes (output)
  • conv_reorder for destination gradient NCHW -> NHWC
    • 0xff00000004600000, 33554432 bytes (input)
    • 0xff0000000a848080, 33554432 bytes (output)
  • gen_conv - backward by data convolution
    • 0xff0000000a800000, 294912 bytes (weights, input)
    • 0xff0000000a848080, 33554432 bytes (destination gradient, input)
    • 0xff00000008600000, 33554432 bytes (source gradient, output)
  • zero_out - to zero out weights gradient for backward by weights convolution
    • 0xff0000000ea00080, 294912 bytes (output)

Based on the previous analysis, I assume the issue is in the first three kernels (they are related to backward by data convolution). What is interesting is that the faulting address 0xff0000000ec9b000 (page fault at read) is pretty far away from the preceding backward by data buffers, ~70 MB after 0xff0000000a848080.
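
Those byte counts are consistent with f64 tensors of the reproducer's shapes; a quick sanity check:

# Sketch: the allocation sizes listed above match the reproducer's f64 tensor shapes.
weights_bytes = 64 * 64 * 3 * 3 * 8        # OIHW conv weights, f64 -> 294912
activation_bytes = 1 * 64 * 256 * 256 * 8  # NCHW activation / gradient, f64 -> 33554432
print(weights_bytes, activation_bytes)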

@hoshibara As for the next steps I suggest the following:

  • Collect a trace with ONEDNN_VERBOSE=2. Verbose mode adds synchronization and apparently the page fault appears earlier, before entering backward by weights. It'd be good to confirm that the faulting address is around the same (now it seems to be rather related to backward by weights which shouldn't be the case)
  • Run the test with ONEAPI_DEVICE_SELECTOR=opencl:* to switch to OpenCL backend. Please collect the trace with OpenCL backend as well. Level Zero traces are harder to analyze because of missing addresses in zeKernelSetArgumentValue so I had to guess some of the addresses above.
  • Also a question: is this a sporadic issue? If you see a page fault and run the same test many times (say 100) - does it show page fault at every run? There may be hardware-related issues but usually they are sporadic, sometimes pass, sometimes fail. As for passing/failing on different devices, it may be due to different allocation addresses (my best guess).

As for further debugging, we usually try some simple tricks:

  • Reducing the test case: e.g. it now includes three primitive calls; can we do it with one primitive? Reproducing the issue with benchdnn could be a great help as well
  • Disabling kernels: the backward by data primitive includes three kernels; what if we skip some of them? This should point to the faulty kernel. It would require updating/rebuilding oneDNN and re-running the test with the updated library
  • Focusing on one buffer at a time. Apparently some allocations are hidden (and reused) in this PT example. We can check all primitive arguments: allocate a large dedicated buffer for the argument and pass it manually. If the issue disappears, this would point to a problematic buffer
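
For the last point, one rough way to give an argument a dedicated allocation from the PyTorch side is to back the tensor with a deliberately over-sized buffer; a sketch with a hypothetical helper (pad_factor is arbitrary), which only approximates passing the buffer manually at the oneDNN level:

# Sketch: re-materialize a suspect tensor on top of a larger, dedicated allocation.
import torch

def with_padded_storage(t, pad_factor=4):
    backing = torch.empty(t.numel() * pad_factor, dtype=t.dtype, device=t.device)
    padded = backing[: t.numel()].view(t.shape)  # the view keeps the large backing allocation alive
    padded.copy_(t)
    return padded

# e.g. swap in one argument at a time in the reproducer:
# grad_dpcpp = with_padded_storage(grad_dpcpp)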

@hoshibara
Author

hoshibara commented Mar 19, 2025

Hi, @echeresh
Thank you for your response. I tried the suggestions you provided, and the results are as follows:

  1. ONEDNN_VERBOSE=2
    The command used was:
ONEDNN_VERBOSE=2 ZE_AFFINITY_MASK=0 onetrace -c python onednn-reproducer.py > onetrace.verbose-2.err.log 2>&1

The log is as follows:
onetrace.verbose-2.mask-0.err.log

In this log, I observed the page fault occurring during the gen_conv step.

  2. ONEAPI_DEVICE_SELECTOR=opencl:*
    This parameter causes Python to fail to find the XPU:
RuntimeError: No XPU devices are available.

  3. Is this a sporadic issue?
    No, the reproduction rate is as high as 94/100 attempts (see the loop sketch after this list).
    With ONEDNN_VERBOSE=2, the reproduction rate increases to 100/100.

  4. benchdnn reproducer
    I tried to use benchdnn, but it seems to only evaluate a single op.
    This case only appears on conv2d after the matmul.

  5. Disabling kernels
    Do you mean we could try to disable the gen_conv step?
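
A rate like the one above can be measured by re-running the reproducer in a loop and counting aborted runs; a rough sketch (onednn-reproducer.py is the script name from the earlier commands):

# Sketch: count how many of N runs abort (the page fault ends in a non-zero exit code).
import os
import subprocess

runs, failures = 100, 0
env = dict(os.environ, ZE_AFFINITY_MASK="2")
for _ in range(runs):
    result = subprocess.run(["python", "onednn-reproducer.py"], env=env, capture_output=True)
    failures += int(result.returncode != 0)
print(f"{failures}/{runs} runs failed")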

@echeresh
Contributor

echeresh commented Mar 20, 2025

@hoshibara Thanks, the faulting address looks similar, ~35 MB after the scratchpad bounds.

ONEAPI_DEVICE_SELECTOR=opencl:*
This parameter causes Python to fail to find the XPU:

Okay, I guess PyTorch doesn't support the OpenCL backend...

Do you mean we could try to disable the gen_conv step?

Yes - disable the gen_conv kernel or the conv_reorder kernels. For that you would need to adjust the oneDNN sources. If PyTorch links with oneDNN dynamically (libdnnl.so dependency), I think we can easily debug it if it reproduces on our side.

So far, I have only one suggestion: can you please collect and share the oneDNN kernel dumps (*.bin files)? For that you need to set ONEDNN_JIT_DUMP=1. I can compare them with the binaries on my side to see if the kernels differ (they should be exactly the same).

If this doesn't clarify anything, the next thing I would try is to reproduce and debug it on our side - for that, can you please share the links/steps you used to install and run PyTorch?

@hoshibara
Author

@echeresh
Thanks! I’ll try to skip the gen_conv stage.

Here are my dump files:

ONEDNN_VERBOSE=2 ZE_AFFINITY_MASK=0 ONEDNN_JIT_DUMP=1 python -u onednn-reproducer.py

dump_onednn.zip

And my steps to install PyTorch are as follows:

CONDA_ENV_NAME=onednn-reproducer
conda create -n $CONDA_ENV_NAME python=3.10 -y

conda deactivate||source deactivate||true
conda activate $CONDA_ENV_NAME||source activate $CONDA_ENV_NAME||true

ONEAPI_HOME=/home/sdp/intel/oneapi

source ${ONEAPI_HOME}/compiler/latest/env/vars.sh
source ${ONEAPI_HOME}/tbb/latest/env/vars.sh
source ${ONEAPI_HOME}/umf/latest/env/vars.sh
source ${ONEAPI_HOME}/pti/latest/env/vars.sh
source ${ONEAPI_HOME}/ccl/latest/env/vars.sh

export USE_XPU=1
export PYTORCH_ENABLE_XPU_FALLBACK=0
export PYTORCH_DEBUG_XPU_FALLBACK=1

export USE_KINETO=1

export BUILD_SEPARATE_OPS=ON

export USE_AOT_DEVLIST='pvc'
export TORCH_XPU_ARCH_LIST='pvc'

export _GLIBCXX_USE_CXX11_ABI=1
export CMAKE_PREFIX_PATH=${CONDA_PREFIX:-"$(dirname $(which conda))/../"}

git clone https://github.com/pytorch/pytorch.git pytorch
cd pytorch
git checkout 037d7af778dc26878307825deb559e9c767ed2ba
cd ..

# you can modify onednn version here:
# pytorch/cmake/Modules/FindMKLDNN.cmake:52

cd pytorch
pip uninstall torch -y
pip uninstall torch -y
pip uninstall torch -y
git submodule sync
git submodule update --init --recursive
pip install -r requirements.txt
make triton
python setup.py develop
cd ..

@rcao8

rcao8 commented Mar 20, 2025

@echeresh Here is another easy way to set up a PyTorch env and reproduce the issue.
Steps:

  1. setup venv
    python -m venv myenv

  2. activate env
    source myenv/bin/activate

  3. install pytorch
    pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/xpu

  4. run the reproducer
    DNNL_VERBOSE=1 python pytorch-reproducer.py #this pytorch-reproducer.py file comes from the post by @hoshibara

@rcao8

rcao8 commented Mar 20, 2025

Can't reproduce the issue by running matmul and then running conv ops, for example:

int main() {
    // Run matmul on GPU
    run_matmul(engine::kind::gpu);

    // Run convolution on GPU
    run_convolution(engine::kind::gpu);

    std::cout << "oneDNN test cases executed successfully." << std::endl;
    return 0;
}

@dzarukin
Contributor

Didn't reproduce the issue by running matmul and then running conv ops, for example:

int main() {
    // Run matmul on GPU
    run_matmul(engine::kind::gpu);

    // Run convolution on GPU
    run_convolution(engine::kind::gpu);

    std::cout << "oneDNN test cases executed successfully." << std::endl;
    return 0;
}

You'd need at least two convolutions, one forward and one backward data (and maybe one backward weights as well), according to the messages above.

@echeresh
Contributor

echeresh commented Mar 21, 2025

@hoshibara @rcao8 Thanks. I was able to reproduce it and reduce it to two primitives: MatMul (called once) and deconvolution (in a loop). The loop part has only one kernel (deconvolution) and doesn't have any kernel creations or memory allocations beyond the first iteration. So we basically execute the same kernel with the same buffers over and over again. The fact that it occasionally generates a page fault after some time points to a hardware bug. I'll work on a C++ reproducer next week to try different things to understand what triggers this behavior.

import torch
import torch.nn as nn

cpu_device = torch.device("cpu")
dpcpp_device = torch.device("xpu")

O, I, H, W = 128, 128, 128, 128

def mm():
    m, n = 1, 100
    a = torch.rand([m, n], device=dpcpp_device, dtype=torch.float16)
    b = torch.rand([n, n], device=dpcpp_device, dtype=torch.float16)
    torch.mm(a, b)

conv_type=torch.double
x_cpu = torch.randn([1, O, H, W], dtype=conv_type, device=cpu_device)
w_cpu = torch.randn([O, I, 1, 1], dtype=conv_type, device=cpu_device)
x_cpu = x_cpu.to(memory_format=torch.channels_last)
deconv_cpu = nn.ConvTranspose2d(O, I, kernel_size=1, stride=1, padding=0, bias=False).double()
x_dpcpp = x_cpu.to(dpcpp_device)
deconv_dpcpp = deconv_cpu.to(dpcpp_device)

mm()
for i in range(100):
    y_dpcpp = deconv_dpcpp(x_dpcpp)

Page fault:

$ ONEDNN_VERBOSE=1 python 1.py
onednn_verbose,v1,info,oneDNN v3.7.1 (commit 8d263e693366ef8db40acc569cc7d8edf644556d)
onednn_verbose,v1,info,cpu,runtime:threadpool,nthr:52
onednn_verbose,v1,info,cpu,isa:Intel AVX-512 with float16, Intel DL Boost and bfloat16 support
onednn_verbose,v1,info,gpu,runtime:DPC++
onednn_verbose,v1,info,gpu,engine,sycl gpu device count:1
onednn_verbose,v1,info,gpu,engine,0,backend:Level Zero,name:Intel(R) Data Center GPU Max 1550,driver_version:1.3.30049,binary_kernels:enabled
onednn_verbose,v1,info,graph,backend,0:dnnl_backend
onednn_verbose,v1,primitive,info,template:operation,engine,primitive,implementation,prop_kind,memory_descriptors,attributes,auxiliary,problem_desc,exec_time
onednn_verbose,v1,graph,info,template:operation,engine,partition_id,partition_kind,op_names,data_formats,logical_tensors,fpmath_mode,implementation,backend,exec_time
onednn_verbose,v1,primitive,exec,gpu,matmul,jit:gemm:any,undef,src:f16::blocked:ab::f0 wei:f16::blocked:ab::f0 dst:f16::blocked:ab::f0,attr-scratchpad:user,,1x100:100x100,0.215088
onednn_verbose,v1,primitive,exec,gpu,deconvolution,conv:any+jit:ir,forward_training,src:f64::blocked:acdb::f0 wei:f64::blocked:bacd::f0 bia:undef::undef::: dst:f64::blocked:acdb::f0,attr-scratchpad:user,alg:deconvolution_direct,mb1_ic128oc128_ih128oh128kh1sh1dh0ph0_iw128ow128kw1sw1dw0pw0,0.203125
onednn_verbose,v1,primitive,exec,gpu,deconvolution,conv:any+jit:ir,forward_training,src:f64::blocked:acdb::f0 wei:f64::blocked:bacd::f0 bia:undef::undef::: dst:f64::blocked:acdb::f0,attr-scratchpad:user,alg:deconvolution_direct,mb1_ic128oc128_ih128oh128kh1sh1dh0ph0_iw128ow128kw1sw1dw0pw0,0.0830078
onednn_verbose,v1,primitive,exec,gpu,deconvolution,conv:any+jit:ir,forward_training,src:f64::blocked:acdb::f0 wei:f64::blocked:bacd::f0 bia:undef::undef::: dst:f64::blocked:acdb::f0,attr-scratchpad:user,alg:deconvolution_direct,mb1_ic128oc128_ih128oh128kh1sh1dh0ph0_iw128ow128kw1sw1dw0pw0,0.0688477
onednn_verbose,v1,primitive,exec,gpu,deconvolution,conv:any+jit:ir,forward_training,src:f64::blocked:acdb::f0 wei:f64::blocked:bacd::f0 bia:undef::undef::: dst:f64::blocked:acdb::f0,attr-scratchpad:user,alg:deconvolution_direct,mb1_ic128oc128_ih128oh128kh1sh1dh0ph0_iw128ow128kw1sw1dw0pw0,0.0671387
onednn_verbose,v1,primitive,exec,gpu,deconvolution,conv:any+jit:ir,forward_training,src:f64::blocked:acdb::f0 wei:f64::blocked:bacd::f0 bia:undef::undef::: dst:f64::blocked:acdb::f0,attr-scratchpad:user,alg:deconvolution_direct,mb1_ic128oc128_ih128oh128kh1sh1dh0ph0_iw128ow128kw1sw1dw0pw0,0.0849609
onednn_verbose,v1,primitive,exec,gpu,deconvolution,conv:any+jit:ir,forward_training,src:f64::blocked:acdb::f0 wei:f64::blocked:bacd::f0 bia:undef::undef::: dst:f64::blocked:acdb::f0,attr-scratchpad:user,alg:deconvolution_direct,mb1_ic128oc128_ih128oh128kh1sh1dh0ph0_iw128ow128kw1sw1dw0pw0,0.103027
onednn_verbose,v1,primitive,exec,gpu,deconvolution,conv:any+jit:ir,forward_training,src:f64::blocked:acdb::f0 wei:f64::blocked:bacd::f0 bia:undef::undef::: dst:f64::blocked:acdb::f0,attr-scratchpad:user,alg:deconvolution_direct,mb1_ic128oc128_ih128oh128kh1sh1dh0ph0_iw128ow128kw1sw1dw0pw0,0.064209
onednn_verbose,v1,primitive,exec,gpu,deconvolution,conv:any+jit:ir,forward_training,src:f64::blocked:acdb::f0 wei:f64::blocked:bacd::f0 bia:undef::undef::: dst:f64::blocked:acdb::f0,attr-scratchpad:user,alg:deconvolution_direct,mb1_ic128oc128_ih128oh128kh1sh1dh0ph0_iw128ow128kw1sw1dw0pw0,0.0620117
onednn_verbose,v1,primitive,exec,gpu,deconvolution,conv:any+jit:ir,forward_training,src:f64::blocked:acdb::f0 wei:f64::blocked:bacd::f0 bia:undef::undef::: dst:f64::blocked:acdb::f0,attr-scratchpad:user,alg:deconvolution_direct,mb1_ic128oc128_ih128oh128kh1sh1dh0ph0_iw128ow128kw1sw1dw0pw0,0.0629883
onednn_verbose,v1,primitive,exec,gpu,deconvolution,conv:any+jit:ir,forward_training,src:f64::blocked:acdb::f0 wei:f64::blocked:bacd::f0 bia:undef::undef::: dst:f64::blocked:acdb::f0,attr-scratchpad:user,alg:deconvolution_direct,mb1_ic128oc128_ih128oh128kh1sh1dh0ph0_iw128ow128kw1sw1dw0pw0,0.996826
FATAL: Unexpected page fault from GPU at 0xff00fffffff9e000, ctx_id: 1 (CCS) type: 0 (NotPresent), level: 1 (PDE), access: 0 (Read), banned: 1, aborting.
FATAL: Unexpected page fault from GPU at 0xff00fffffff9e000, ctx_id: 1 (CCS) type: 0 (NotPresent), level: 1 (PDE), access: 0 (Read), banned: 1, aborting.
Abort was called at 287 line in file:
./shared/source/os_interface/linux/drm_neo.cpp
Aborted (core dumped)

@echeresh
Contributor

@hoshibara @rcao8 The preliminary conclusion is that this is caused by a known double-data-type-related hardware bug in PVC. I was able to reproduce it in benchdnn, so it's completely on our side from this point - we'll investigate it and prepare a fix. The fix should be available in a future oneDNN release.

As a mitigation, I see that overriding the thread arbitration policy makes the issue non-reproducible, but it may come with a performance cost and may still produce the same issue, just at a rarer rate:

$ OverrideThreadArbitrationPolicy=2 NEOReadDebugKeys=1 PrintDebugSettings=1 python 1.py
Non-default value of debug variable: PrintDebugSettings = 1
Non-default value of debug variable: OverrideThreadArbitrationPolicy = 2
...

@petercad
Contributor

petercad commented Mar 22, 2025

@hoshibara @rcao8 We were able to confirm that the PVC HW bug @echeresh mentioned is the culprit. The linked PR (#2935) should fix the issue you're seeing.

@hoshibara
Author

@petercad @echeresh Thanks a lot for confirming and fixing this issue.
I have verified that the latest oneDNN in PyTorch no longer triggers this issue.
