Commit 096a3ea

implement transformers activation functions

committed
1 parent 3824d86 commit 096a3ea

File tree

1 file changed: +2 −16 lines changed

moduleformer/modeling_moduleformer.py (+2 −16)
@@ -9,7 +9,7 @@
 from torch.nn import CrossEntropyLoss, MSELoss, BCEWithLogitsLoss
 from torch.nn import functional as F
 
-from transformers.activations import ACT2FN
+from transformers.activations import get_activation
 from transformers.modeling_outputs import (
     BaseModelOutputWithPast,
     CausalLMOutputWithPast,
@@ -33,20 +33,6 @@
 # ]
 
 
-@torch.jit.script
-def NewGELU(x):
-    """
-    Compute the NewGELU activation function.
-
-    Args:
-        x (torch.Tensor): Input tensor.
-
-    Returns:
-        torch.Tensor: Output tensor after applying NewGELU activation.
-    """
-    return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
-
-
 @torch.jit.script
 def stickbreaking_att(
     q: torch.Tensor,
@@ -230,7 +216,7 @@ def __init__(self, config):
             num_experts=config.n_mlp_experts,
             top_k=config.k_mlp,
             bias=False,
-            activation=NewGELU,
+            activation=get_activation(config.activation_function),
             acc_aux_loss=False,
             gating_dropout=config.moe_pdrop,
             sample_topk=config.sample_topk,
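
In effect, the commit replaces the hand-written, JIT-scripted NewGELU with whatever activation the model config names, looked up through transformers' get_activation helper. When config.activation_function is "gelu_new", the built-in activation in Hugging Face transformers uses the same tanh approximation as the deleted function, so the swap should be numerically a no-op in that case. Below is a minimal sketch of that equivalence check; it is not part of the commit and assumes only torch and transformers are installed:

    # Sketch (not from the commit): verify that transformers' "gelu_new"
    # matches the NewGELU implementation this commit removed.
    import math
    import torch
    from transformers.activations import get_activation

    def new_gelu(x: torch.Tensor) -> torch.Tensor:
        # Tanh approximation of GELU, as in the deleted function.
        return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))

    x = torch.randn(1024)
    builtin = get_activation("gelu_new")  # returns an nn.Module instance
    torch.testing.assert_close(new_gelu(x), builtin(x))
    print("gelu_new matches the removed NewGELU implementation")

A side benefit of delegating to get_activation is that other standard config values ("gelu", "relu", "silu", and so on) now work without further code changes; the one trade-off is that the activation is no longer wrapped in @torch.jit.script, which may or may not matter for performance depending on how the surrounding module is compiled.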
