If the weight is frozen (e.g., in a LoRA setting), we can avoid quantizing grad_output in the columnwise direction.
The following patch seems to work:
```diff
diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py
index 83dc652c..3e58700a 100644
--- a/transformer_engine/pytorch/module/linear.py
+++ b/transformer_engine/pytorch/module/linear.py
@@ -423,6 +423,11 @@ class _Linear(torch.autograd.Function):
                 ub_obj_wgrad.set_buffer_params(ctx.grad_input_quantizer)
                 dgrad_bulk = ub_obj_wgrad.get_buffer(ctx.grad_input_quantizer)
 
+        if not ctx.requires_dgrad and ctx.grad_output_quantizer is not None:
+            ctx.grad_output_quantizer.set_usage(rowwise=False)
+        if not ctx.requires_wgrad and ctx.grad_output_quantizer is not None:
+            ctx.grad_output_quantizer.set_usage(columnwise=False)
+
         # Prepare grad output tensor
         # Note: Cast to expected dtype and perform tensor-parallel communication
         if ctx.grad_output_quantizer is not None:
```
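For readers less familiar with the quantizer usage flags, here is a minimal standalone sketch of the logic the patch adds (toy classes and names, not TE internals): each layout of grad_output is only requested when the corresponding gradient is actually needed.

```python
from dataclasses import dataclass


@dataclass
class ToyQuantizer:
    # Toy stand-in for a TE quantizer: tracks which data layouts it will produce.
    rowwise: bool = True
    columnwise: bool = True

    def set_usage(self, rowwise=None, columnwise=None):
        if rowwise is not None:
            self.rowwise = rowwise
        if columnwise is not None:
            self.columnwise = columnwise


def configure_grad_output_quantizer(quantizer, requires_dgrad, requires_wgrad):
    # Mirrors the patch: dgrad consumes grad_output rowwise, wgrad consumes it
    # columnwise (i.e. transposed), so skip whichever copy is not needed.
    if not requires_dgrad:
        quantizer.set_usage(rowwise=False)
    if not requires_wgrad:
        quantizer.set_usage(columnwise=False)
    return quantizer


# Frozen weight (LoRA-style): no wgrad, so no columnwise/transposed FP8 copy.
q = configure_grad_output_quantizer(ToyQuantizer(), requires_dgrad=True, requires_wgrad=False)
print(q)  # ToyQuantizer(rowwise=True, columnwise=False)
```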
Test Script
```python
import torch
import transformer_engine
from transformer_engine.pytorch import fp8_autocast, Linear

dim = 1024 * 22  # Large input for demonstration of memory change.
linear = Linear(dim, dim, bias=False)
x = torch.randn(dim, dim, requires_grad=True, device="cuda")
linear.weight.requires_grad = False

with fp8_autocast():
    o = linear(x)

g_o = torch.randn_like(o)
o.backward(g_o)

# Without patch - 12314.476544 MB
# With patch    - 11790.188544 MB
print(torch.cuda.max_memory_allocated() / 1e6, "MB")
```
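As a rough sanity check on the numbers above (my own back-of-the-envelope arithmetic, assuming the saving comes from no longer materializing the columnwise/transposed FP8 copy of grad_output):

```python
dim = 1024 * 22
fp8_copy_bytes = dim * dim  # one byte per element for an FP8 tensor of shape (dim, dim)
print(fp8_copy_bytes / 1e6, "MB")  # ~507.5 MB, in the same ballpark as the
                                   # ~524 MB difference reported above
```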
https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/examples/fp8_primer.html#Handling-transposes

Based on the diagram there, if the weight is frozen (e.g., in a LoRA setting), quantizing grad_output in the columnwise direction can indeed be avoided.
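For intuition on why the two usages exist, here is a plain-PyTorch sketch (no FP8; shapes are illustrative) of the backward of y = x @ w.T: the input gradient consumes grad_output as-is, while the weight gradient consumes its transpose, which corresponds to the columnwise copy that becomes unnecessary when the weight is frozen.

```python
import torch

x = torch.randn(8, 16)   # input
w = torch.randn(32, 16)  # weight of a Linear(16, 32)
g = torch.randn(8, 32)   # grad_output for y = x @ w.T

dgrad = g @ w            # dL/dx: uses grad_output as-is ("rowwise")
wgrad = g.t() @ x        # dL/dW: uses grad_output transposed ("columnwise")
# If w.requires_grad is False, wgrad is never computed, so the transposed
# (columnwise-quantized) copy of grad_output is never needed.
```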