decoder_1.py
import torch.nn as nn


class TransformerDecoderLayer(nn.Module):
    """A single Transformer decoder layer: self-attention over the target,
    cross-attention over the encoder memory, and a position-wise feed-forward
    network, each followed by dropout and a residual connection."""

    def __init__(self, d_model, nhead, dim_feedforward=16, dropout=0):
        super(TransformerDecoderLayer, self).__init__()
        # Self-attention over the target sequence.
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Cross-attention: queries from the target, keys/values from the memory.
        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Position-wise feed-forward network.
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        # Dropout applied to each sublayer output before the residual addition.
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)
        self.activation = nn.LeakyReLU(inplace=True)

    def forward(self, tgt, memory, tgt_mask=None, memory_mask=None,
                tgt_key_padding_mask=None, memory_key_padding_mask=None):
        # Self-attention sublayer with residual connection.
        tgt2 = self.self_attn(tgt, tgt, tgt, attn_mask=tgt_mask,
                              key_padding_mask=tgt_key_padding_mask)[0]
        tgt = tgt + self.dropout1(tgt2)
        # Cross-attention sublayer with residual connection.
        tgt2 = self.multihead_attn(tgt, memory, memory, attn_mask=memory_mask,
                                   key_padding_mask=memory_key_padding_mask)[0]
        tgt = tgt + self.dropout2(tgt2)
        # Feed-forward sublayer with residual connection.
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout3(tgt2)
        return tgt