from torch.nn import CrossEntropyLoss, MSELoss, BCEWithLogitsLoss
from torch.nn import functional as F

- from transformers.activations import ACT2FN
+ from transformers.activations import get_activation
from transformers.modeling_outputs import (
    BaseModelOutputWithPast,
    CausalLMOutputWithPast,
@@ -33,20 +33,6 @@
# ]


- @torch.jit.script
- def NewGELU(x):
-     """
-     Compute the NewGELU activation function.
-
-     Args:
-         x (torch.Tensor): Input tensor.
-
-     Returns:
-         torch.Tensor: Output tensor after applying NewGELU activation.
-     """
-     return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
-
-
@torch.jit.script
def stickbreaking_att(
    q: torch.Tensor,
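
Aside (not part of the patch): the deleted NewGELU helper matches the tanh-approximation GELU that transformers already ships under the "gelu_new" key, which is presumably why the commit switches to get_activation. A minimal sketch checking the equivalence, assuming only the public transformers.activations API:

```python
import math

import torch
from transformers.activations import get_activation

# "gelu_new" resolves to the same tanh-approximation GELU that the removed
# NewGELU function computed by hand.
act = get_activation("gelu_new")

x = torch.randn(8)
reference = 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
assert torch.allclose(act(x), reference)
```
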
@@ -230,7 +216,7 @@ def __init__(self, config):
            num_experts=config.n_mlp_experts,
            top_k=config.k_mlp,
            bias=False,
-             activation=NewGELU,
+             activation=get_activation(config.activation_function),
            acc_aux_loss=False,
            gating_dropout=config.moe_pdrop,
            sample_topk=config.sample_topk,
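
With this change the MoE layer receives an instantiated activation module instead of the jit-scripted function, and the choice becomes configurable through config.activation_function. A small sketch, assuming the config value is a standard registry key such as "gelu_new" (the MoE layer only needs a callable):

```python
import torch
from transformers.activations import get_activation

activation_function = "gelu_new"  # assumed/typical value of config.activation_function

act = get_activation(activation_function)
print(isinstance(act, torch.nn.Module))  # True: a ready-to-call activation module
print(act(torch.zeros(3)))               # tensor([0., 0., 0.])
```
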
@@ -333,17 +319,26 @@ def _init_weights(self, module):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

-     def gradient_checkpointing_enable(self):
+     def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs={}):
        for module in self.modules():
            if hasattr(module, "gradient_checkpointing"):
-                 self._set_gradient_checkpointing(module, True)
+                 self._set_gradient_checkpointing(
+                     module, True, gradient_checkpointing_kwargs
+                 )

    def gradient_checkpointing_disable(self):
        for module in self.modules():
            if hasattr(module, "gradient_checkpointing"):
-                 self._set_gradient_checkpointing(module, False)
+                 self._set_gradient_checkpointing(
+                     module, False
+                 )

-     def _set_gradient_checkpointing(self, module, value=False):
+     def _set_gradient_checkpointing(
+         self,
+         module,
+         value=False,
+         gradient_checkpointing_kwargs={"use_reentrant": False},
+     ):
        """
        Set gradient checkpointing for the ModuleFormerModel.

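
For reference, a hedged usage sketch of the new entry point; the checkpoint path is a placeholder, and trust_remote_code is assumed because ModuleFormer ships its modeling code with the checkpoint. The keyword mirrors the gradient_checkpointing_kwargs argument used by upstream transformers:

```python
from transformers import AutoModelForCausalLM

# Placeholder checkpoint id; substitute an actual ModuleFormer/MoLM checkpoint.
model = AutoModelForCausalLM.from_pretrained(
    "path/to/moduleformer-checkpoint", trust_remote_code=True
)

# Forwarded down to _set_gradient_checkpointing and stored on the inner model.
model.gradient_checkpointing_enable(
    gradient_checkpointing_kwargs={"use_reentrant": False}
)
```
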
@@ -353,6 +348,7 @@ def _set_gradient_checkpointing(self, module, value=False):
        """
        if isinstance(module, ModuleFormerModel):
            module.gradient_checkpointing = value
+             module.gradient_checkpointing_kwargs = gradient_checkpointing_kwargs


SPARSEGPT_START_DOCSTRING = r"""
@@ -554,6 +550,7 @@ def custom_forward(*inputs):
                    None,
                    attention_mask,
                    head_mask[i],
+                     **self.gradient_checkpointing_kwargs,
                )
            else:
                outputs = block(
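
The lines above sit inside the gradient-checkpointed branch of the forward pass, so the splatted kwargs are presumably consumed by torch.utils.checkpoint.checkpoint. A minimal, self-contained sketch of that mechanism with an illustrative layer (not the ModuleFormer block itself):

```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

layer = nn.Linear(8, 8)
x = torch.randn(2, 8, requires_grad=True)
gradient_checkpointing_kwargs = {"use_reentrant": False}

# Keyword arguments splatted into checkpoint() control its behaviour, e.g.
# selecting the non-reentrant implementation recommended by recent PyTorch.
y = checkpoint(layer, x, **gradient_checkpointing_kwargs)
y.sum().backward()
print(x.grad.shape)  # torch.Size([2, 8])
```
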