
Commit 850f1eb

Merge pull request #8 from LuciferianInk/enable-gradient-checkpointing
Implement transformers activation functions
2 parents 8e6f929 + 096a3ea commit 850f1eb

File tree

1 file changed: +17 -20 lines changed


moduleformer/modeling_moduleformer.py

+17 -20
@@ -9,7 +9,7 @@
 from torch.nn import CrossEntropyLoss, MSELoss, BCEWithLogitsLoss
 from torch.nn import functional as F

-from transformers.activations import ACT2FN
+from transformers.activations import get_activation
 from transformers.modeling_outputs import (
     BaseModelOutputWithPast,
     CausalLMOutputWithPast,
@@ -33,20 +33,6 @@
 # ]


-@torch.jit.script
-def NewGELU(x):
-    """
-    Compute the NewGELU activation function.
-
-    Args:
-        x (torch.Tensor): Input tensor.
-
-    Returns:
-        torch.Tensor: Output tensor after applying NewGELU activation.
-    """
-    return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
-
-
 @torch.jit.script
 def stickbreaking_att(
     q: torch.Tensor,
@@ -230,7 +216,7 @@ def __init__(self, config):
             num_experts=config.n_mlp_experts,
             top_k=config.k_mlp,
             bias=False,
-            activation=NewGELU,
+            activation=get_activation(config.activation_function),
             acc_aux_loss=False,
             gating_dropout=config.moe_pdrop,
             sample_topk=config.sample_topk,
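
For reference, the deleted NewGELU is the same tanh approximation that transformers registers under the key "gelu_new", so `get_activation(config.activation_function)` reproduces the old behaviour when the config selects that key, while also letting the config choose any other registered activation. A minimal sketch, assuming `activation_function == "gelu_new"`:

```python
import math
import torch
from transformers.activations import get_activation

# Registry lookup that replaces the deleted @torch.jit.script NewGELU function.
act = get_activation("gelu_new")  # assumes config.activation_function == "gelu_new"

x = torch.randn(4, 8)
# The formula removed by this commit, kept here only for comparison.
reference = 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
print(torch.allclose(act(x), reference))  # True
```
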
@@ -333,17 +319,26 @@ def _init_weights(self, module):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)

-    def gradient_checkpointing_enable(self):
+    def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs={}):
         for module in self.modules():
             if hasattr(module, "gradient_checkpointing"):
-                self._set_gradient_checkpointing(module, True)
+                self._set_gradient_checkpointing(
+                    module, True, gradient_checkpointing_kwargs
+                )

     def gradient_checkpointing_disable(self):
         for module in self.modules():
             if hasattr(module, "gradient_checkpointing"):
-                self._set_gradient_checkpointing(module, False)
+                self._set_gradient_checkpointing(
+                    module, False
+                )

-    def _set_gradient_checkpointing(self, module, value=False):
+    def _set_gradient_checkpointing(
+        self,
+        module,
+        value=False,
+        gradient_checkpointing_kwargs={"use_reentrant": False},
+    ):
         """
         Set gradient checkpointing for the ModuleFormerModel.

@@ -353,6 +348,7 @@ def _set_gradient_checkpointing(self, module, value=False):
         """
         if isinstance(module, ModuleFormerModel):
             module.gradient_checkpointing = value
+            module.gradient_checkpointing_kwargs = gradient_checkpointing_kwargs


 SPARSEGPT_START_DOCSTRING = r"""
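
The new enable/disable methods mirror the `gradient_checkpointing_enable(gradient_checkpointing_kwargs=...)` interface of recent transformers releases, so a training script can pass checkpointing options through the model. A hedged usage sketch; the `moduleformer` import path and the `ModuleFormerConfig`/`ModuleFormerForCausalLM` names are assumed from the repository layout and are not shown in this diff:

```python
# Assumed names: ModuleFormerConfig and ModuleFormerForCausalLM from the
# moduleformer package; the diff itself only shows ModuleFormerModel.
from moduleformer import ModuleFormerConfig, ModuleFormerForCausalLM

model = ModuleFormerForCausalLM(ModuleFormerConfig())

# The options dict is recorded by _set_gradient_checkpointing on every
# submodule that exposes a `gradient_checkpointing` flag.
model.gradient_checkpointing_enable({"use_reentrant": False})
# ... training with checkpointed blocks ...
model.gradient_checkpointing_disable()
```
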
@@ -554,6 +550,7 @@ def custom_forward(*inputs):
                     None,
                     attention_mask,
                     head_mask[i],
+                    **self.gradient_checkpointing_kwargs,
                 )
             else:
                 outputs = block(
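
Downstream, the stored kwargs are expanded into the `torch.utils.checkpoint.checkpoint` call shown in the last hunk, which is how `use_reentrant=False` reaches autograd. A simplified, self-contained sketch of that call pattern; a toy block stands in for a ModuleFormer layer, and the `None`/`attention_mask`/`head_mask` arguments of the real call are omitted:

```python
import torch
import torch.utils.checkpoint
from torch import nn

# Toy stand-in for one decoder block; the real call also passes layer_past,
# attention_mask and head_mask, as the hunk above shows.
block = nn.Sequential(nn.Linear(16, 16), nn.GELU())
hidden_states = torch.randn(2, 4, 16, requires_grad=True)

gradient_checkpointing_kwargs = {"use_reentrant": False}

def create_custom_forward(module):
    def custom_forward(*inputs):
        return module(*inputs)
    return custom_forward

# The dict stored by _set_gradient_checkpointing is splatted into keyword
# arguments of torch.utils.checkpoint.checkpoint (e.g. use_reentrant=False).
out = torch.utils.checkpoint.checkpoint(
    create_custom_forward(block),
    hidden_states,
    **gradient_checkpointing_kwargs,
)
out.sum().backward()
print(hidden_states.grad.shape)  # torch.Size([2, 4, 16])
```
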

0 commit comments
