1 file changed: +2 -16 lines changed

@@ -9,7 +9,7 @@
 from torch.nn import CrossEntropyLoss, MSELoss, BCEWithLogitsLoss
 from torch.nn import functional as F

-from transformers.activations import ACT2FN
+from transformers.activations import get_activation
 from transformers.modeling_outputs import (
     BaseModelOutputWithPast,
     CausalLMOutputWithPast,
@@ -33,20 +33,6 @@
 # ]


-@torch.jit.script
-def NewGELU(x):
-    """
-    Compute the NewGELU activation function.
-
-    Args:
-        x (torch.Tensor): Input tensor.
-
-    Returns:
-        torch.Tensor: Output tensor after applying NewGELU activation.
-    """
-    return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
-
-
 @torch.jit.script
 def stickbreaking_att(
     q: torch.Tensor,
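The deleted helper is the standard tanh-approximation GELU, which transformers already ships under the key "gelu_new". A minimal sketch checking that the two agree, assuming a transformers version that exposes `get_activation` (the local `new_gelu` below is just a copy of the deleted function, kept for comparison):

```python
import math

import torch
from transformers.activations import get_activation


def new_gelu(x):
    # Local copy of the deleted NewGELU, reproduced only for this comparison.
    return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))


x = torch.randn(4, 8)
# get_activation("gelu_new") returns an nn.Module implementing the same formula.
assert torch.allclose(new_gelu(x), get_activation("gelu_new")(x))
```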
@@ -230,7 +216,7 @@ def __init__(self, config):
             num_experts=config.n_mlp_experts,
             top_k=config.k_mlp,
             bias=False,
-            activation=NewGELU,
+            activation=get_activation(config.activation_function),
             acc_aux_loss=False,
             gating_dropout=config.moe_pdrop,
             sample_topk=config.sample_topk,
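With this change the MoE expert activation is driven by the model config rather than hard-coded. A hedged sketch of the new code path, using a hypothetical stand-in config (`DummyConfig` is not part of this repo):

```python
from transformers.activations import get_activation


class DummyConfig:
    # Hypothetical stand-in for the model config; activation_function can be
    # any key transformers recognizes, e.g. "gelu_new", "gelu", or "silu".
    activation_function = "gelu_new"


config = DummyConfig()
activation = get_activation(config.activation_function)
print(activation)  # NewGELUActivation() when activation_function == "gelu_new"
```

Setting `activation_function` to another supported key swaps the expert nonlinearity without touching the modeling code.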