Skip to content

Commit 3824d86

Browse files
committed
pass gradient_checkpointing_kwargs to methods
1 parent 5ced638 commit 3824d86

File tree

1 file changed

+15
-4
lines changed

1 file changed

+15
-4
lines changed

moduleformer/modeling_moduleformer.py

+15-4
Original file line numberDiff line numberDiff line change
@@ -333,17 +333,26 @@ def _init_weights(self, module):
333333
module.bias.data.zero_()
334334
module.weight.data.fill_(1.0)
335335

336-
def gradient_checkpointing_enable(self):
336+
def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs={}):
337337
for module in self.modules():
338338
if hasattr(module, "gradient_checkpointing"):
339-
self._set_gradient_checkpointing(module, True)
339+
self._set_gradient_checkpointing(
340+
module, True, gradient_checkpointing_kwargs
341+
)
340342

341343
def gradient_checkpointing_disable(self):
342344
for module in self.modules():
343345
if hasattr(module, "gradient_checkpointing"):
344-
self._set_gradient_checkpointing(module, False)
346+
self._set_gradient_checkpointing(
347+
module, False
348+
)
345349

346-
def _set_gradient_checkpointing(self, module, value=False):
350+
def _set_gradient_checkpointing(
351+
self,
352+
module,
353+
value=False,
354+
gradient_checkpointing_kwargs={"use_reentrant": False},
355+
):
347356
"""
348357
Set gradient checkpointing for the ModuleFormerModel.
349358
@@ -353,6 +362,7 @@ def _set_gradient_checkpointing(self, module, value=False):
353362
"""
354363
if isinstance(module, ModuleFormerModel):
355364
module.gradient_checkpointing = value
365+
module.gradient_checkpointing_kwargs = gradient_checkpointing_kwargs
356366

357367

358368
SPARSEGPT_START_DOCSTRING = r"""
@@ -554,6 +564,7 @@ def custom_forward(*inputs):
554564
None,
555565
attention_mask,
556566
head_mask[i],
567+
**self.gradient_checkpointing_kwargs,
557568
)
558569
else:
559570
outputs = block(

0 commit comments

Comments
 (0)