@@ -333,17 +333,26 @@ def _init_weights(self, module):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)
 
-    def gradient_checkpointing_enable(self):
+    def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs={}):
         for module in self.modules():
             if hasattr(module, "gradient_checkpointing"):
-                self._set_gradient_checkpointing(module, True)
+                self._set_gradient_checkpointing(
+                    module, True, gradient_checkpointing_kwargs
+                )
 
     def gradient_checkpointing_disable(self):
         for module in self.modules():
             if hasattr(module, "gradient_checkpointing"):
-                self._set_gradient_checkpointing(module, False)
+                self._set_gradient_checkpointing(
+                    module, False
+                )
 
-    def _set_gradient_checkpointing(self, module, value=False):
+    def _set_gradient_checkpointing(
+        self,
+        module,
+        value=False,
+        gradient_checkpointing_kwargs={"use_reentrant": False},
+    ):
         """
         Set gradient checkpointing for the ModuleFormerModel.
 
@@ -353,6 +362,7 @@ def _set_gradient_checkpointing(self, module, value=False):
         """
         if isinstance(module, ModuleFormerModel):
             module.gradient_checkpointing = value
+            module.gradient_checkpointing_kwargs = gradient_checkpointing_kwargs
 
 
 SPARSEGPT_START_DOCSTRING = r"""
@@ -554,6 +564,7 @@ def custom_forward(*inputs):
                     None,
                     attention_mask,
                     head_mask[i],
+                    **self.gradient_checkpointing_kwargs,
                 )
             else:
                 outputs = block(
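
For context, a minimal usage sketch of the changed API, assuming the model is loaded as a Hugging Face PreTrainedModel; the checkpoint path below is a placeholder and the trust_remote_code flag is an assumption, not part of this diff:

# Sketch only: checkpoint path is hypothetical; loading details are assumptions.
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "path/to/moduleformer-checkpoint",
    trust_remote_code=True,
)

# With this change, the extra kwargs are stored on the model and unpacked into the
# torch.utils.checkpoint.checkpoint() call during the forward pass, so the
# non-reentrant checkpointing variant can be selected at enable time.
model.gradient_checkpointing_enable(
    gradient_checkpointing_kwargs={"use_reentrant": False}
)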