
[MXFP8] grad_output is quantized columnwise even if weight doesn't require gradients. #1693


Open
kshitij12345 opened this issue Apr 17, 2025 · 2 comments · May be fixed by #1736

Labels
bug Something isn't working

Comments

@kshitij12345

https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/examples/fp8_primer.html#Handling-transposes

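Based on the diagram in that section, grad_output is read rowwise by the dgrad GEMM and columnwise (transposed) by the wgrad GEMM.

If the weight is frozen (e.g. in a LoRA setting), no wgrad is computed, so we can skip quantizing grad_output in the columnwise direction.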
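Why a columnwise copy costs a second quantization rather than a free transpose: MXFP8 gives each block of 32 consecutive elements one shared power-of-two (E8M0) scale, so blocks taken along rows and along columns generally get different scales, and therefore different FP8 values. The toy NumPy sketch below computes only the per-block scale exponents in both directions. `block_scale_exponents` is a hypothetical helper, it skips the actual FP8 cast, and its exponent rule follows the OCP MX convention rather than Transformer Engine's exact kernels.

```python
import numpy as np

BLOCK = 32     # MXFP8 block size: 32 elements share one scale
E4M3_EMAX = 8  # largest exponent of FP8 E4M3 (max normal 448 = 1.75 * 2**8)


def block_scale_exponents(x: np.ndarray, axis: int) -> np.ndarray:
    """Toy shared power-of-two exponent per 32-element block along `axis`.

    axis=1 -> rowwise blocks, result shape [M, K // 32]
    axis=0 -> columnwise blocks, result shape [M // 32, K]
    """
    if axis == 0:
        # Columnwise blocks are just rowwise blocks of the transpose.
        return block_scale_exponents(x.T, axis=1).T
    m, k = x.shape
    assert k % BLOCK == 0, "toy version: blocked dimension must be a multiple of 32"
    amax = np.abs(x).reshape(m, k // BLOCK, BLOCK).max(axis=-1)
    amax = np.where(amax == 0, 1.0, amax)  # all-zero block: avoid log2(0)
    # OCP MX rule: floor(log2(amax)) minus the element format's max exponent
    return np.floor(np.log2(amax)).astype(np.int32) - E4M3_EMAX


rng = np.random.default_rng(0)
grad_output = rng.standard_normal((64, 128)).astype(np.float32)
row = block_scale_exponents(grad_output, axis=1)  # layout read by the dgrad GEMM
col = block_scale_exponents(grad_output, axis=0)  # layout read by the wgrad GEMM
print(row.shape, col.shape)  # (64, 4) (2, 128)
```

Because the two sets of scales run along different dimensions, the columnwise copy needs its own quantization pass and its own buffer, which is exactly the work and memory that can be skipped when no wgrad is needed.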

The following patch seems to work:

diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py
index 83dc652c..3e58700a 100644
--- a/transformer_engine/pytorch/module/linear.py
+++ b/transformer_engine/pytorch/module/linear.py
@@ -423,6 +423,11 @@ class _Linear(torch.autograd.Function):
                     ub_obj_wgrad.set_buffer_params(ctx.grad_input_quantizer)
                     dgrad_bulk = ub_obj_wgrad.get_buffer(ctx.grad_input_quantizer)
 
+            if not ctx.requires_dgrad and ctx.grad_output_quantizer is not None:
+                ctx.grad_output_quantizer.set_usage(rowwise=False)
+            if not ctx.requires_wgrad and ctx.grad_output_quantizer is not None:
+                ctx.grad_output_quantizer.set_usage(columnwise=False)
+
             # Prepare grad output tensor
             # Note: Cast to expected dtype and perform tensor-parallel communication
             if ctx.grad_output_quantizer is not None:
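The patch's decision rule can be exercised in isolation with a stand-in object. `ToyQuantizer` and `configure_grad_output_usage` are hypothetical names, not Transformer Engine classes, and the sketch assumes that `set_usage` leaves any direction it is not given unchanged, which is how the patch calls it.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ToyQuantizer:
    """Hypothetical stand-in for a quantizer: tracks which copies it will produce."""
    rowwise_usage: bool = True
    columnwise_usage: bool = True

    def set_usage(self, *, rowwise: Optional[bool] = None,
                  columnwise: Optional[bool] = None) -> None:
        # Assumption: a direction passed as None keeps its current setting.
        if rowwise is not None:
            self.rowwise_usage = rowwise
        if columnwise is not None:
            self.columnwise_usage = columnwise


def configure_grad_output_usage(q: Optional[ToyQuantizer],
                                requires_dgrad: bool,
                                requires_wgrad: bool) -> None:
    """Mirror of the patch: drop each grad_output copy whose consuming GEMM is skipped."""
    if q is None:  # no quantizer, nothing to configure
        return
    if not requires_dgrad:
        q.set_usage(rowwise=False)     # rowwise copy only feeds the dgrad GEMM
    if not requires_wgrad:
        q.set_usage(columnwise=False)  # columnwise copy only feeds the wgrad GEMM


# LoRA-style case from the issue: the input needs a gradient, the frozen weight does not.
q = ToyQuantizer()
configure_grad_output_usage(q, requires_dgrad=True, requires_wgrad=False)
print(q)  # ToyQuantizer(rowwise_usage=True, columnwise_usage=False)
```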

Test Script

import torch

import transformer_engine
from transformer_engine.pytorch import fp8_autocast, Linear

dim = 1024 * 22  # Large input for demonstration of memory change.
linear = Linear(dim, dim, bias=False)
x = torch.randn(dim, dim, requires_grad=True, device="cuda")

linear.weight.requires_grad = False

with fp8_autocast():
    o = linear(x)
    g_o = torch.randn_like(o)

o.backward(g_o)

# Without patch - 12314.476544 MB
# With patch - 11790.188544 MB
print(torch.cuda.max_memory_allocated() / 1e6, "MB")
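The reported saving can be checked by hand. Assuming the MXFP8 recipe is active, one columnwise copy of the 22528 x 22528 grad_output holds 1 byte of FP8 data per element plus one 1-byte E8M0 scale per 32 elements. The sketch additionally assumes that the PyTorch caching allocator keeps the scale buffer in a block rounded up to 2 MiB; without that assumption the estimate is about 523.4 MB. The helper names are illustrative only.

```python
MiB = 1024 ** 2
BLOCK = 32  # MXFP8: one 1-byte E8M0 scale per 32 elements


def columnwise_mxfp8_bytes(rows: int, cols: int) -> tuple[int, int]:
    """(data bytes, scale bytes) for one columnwise MXFP8 copy of a rows x cols tensor."""
    assert rows % BLOCK == 0, "blocks run down the columns, so rows must divide by 32"
    data = rows * cols               # 1 byte per FP8 element
    scales = (rows // BLOCK) * cols  # 1 byte per 32-element block
    return data, scales


def round_up(n: int, multiple: int) -> int:
    return -(-n // multiple) * multiple


dim = 1024 * 22  # same size as the test script
data, scales = columnwise_mxfp8_bytes(dim, dim)
print(data / MiB, scales / MiB)  # 484.0 15.125

exact = data + scales
# Assumption: the scale buffer occupies an allocator block rounded up to 2 MiB.
rounded = data + round_up(scales, 2 * MiB)
measured = 12314.476544 - 11790.188544  # MB, from the comments in the test script

print(exact / 1e6, rounded / 1e6, round(measured, 6))
```

Under that allocator assumption the estimate matches the measured 524.288 MB difference, which is consistent with the patch saving exactly one columnwise copy of grad_output.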

kshitij12345 added the "bug" label on Apr 17, 2025

@ptrendx (Member) commented Apr 25, 2025

Hi @kshitij12345 this makes perfect sense. The proposed solution looks good to me, could you create a PR with it?

@kshitij12345 (Author)

Sure, I will have a PR up soon, thanks!
