
Commit fa78973

Author: GJ98
Commit message: pad mask issue and embedding issue
1 parent db8c1e3 · commit fa78973

32 files changed (+33 -26 lines)

(Recompiled __pycache__/*.pyc files and other binaries are also part of this commit; binary file contents are not shown.)

models/blocks/decoder_layer.py (+2 -2)

@@ -29,7 +29,7 @@ def __init__(self, d_model, ffn_hidden, n_head, drop_prob):
    def forward(self, dec, enc, t_mask, s_mask):
        # 1. compute self attention
        _x = dec
-       x = self.self_attention(q=dec, k=dec, v=dec, mask=trg_mask)
+       x = self.self_attention(q=dec, k=dec, v=dec, mask=t_mask)

        # 2. add and norm
        x = self.norm1(x + _x)
@@ -38,7 +38,7 @@ def forward(self, dec, enc, t_mask, s_mask):
        if enc is not None:
            # 3. compute encoder - decoder attention
            _x = x
-           x = self.enc_dec_attention(q=x, k=enc, v=enc, mask=src_mask)
+           x = self.enc_dec_attention(q=x, k=enc, v=enc, mask=s_mask)

            # 4. add and norm
            x = self.norm2(x + _x)
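
The two renamed arguments are the masks DecoderLayer.forward actually receives; trg_mask and src_mask were leftover names from the old Transformer helpers and are not defined inside the method, so the layer raised a NameError as soon as it ran. A minimal sketch of the failure mode, using hypothetical stand-in classes rather than the repo's DecoderLayer:

import torch

class BrokenLayer:
    def forward(self, dec, enc, t_mask, s_mask):
        # 'trg_mask' is not a parameter and not defined in this scope,
        # so the first call fails with a NameError (the bug this hunk fixes).
        print(trg_mask.size())

class FixedLayer:
    def forward(self, dec, enc, t_mask, s_mask):
        # uses the mask that was actually passed in
        print(t_mask.size())

mask = torch.ones(1, 1, 4, 4, dtype=torch.bool)
FixedLayer().forward(None, None, mask, None)     # torch.Size([1, 1, 4, 4])
# BrokenLayer().forward(None, None, mask, None)  # NameError: name 'trg_mask' is not defined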

models/blocks/encoder_layer.py (+1 -1)

@@ -25,7 +25,7 @@ def __init__(self, d_model, ffn_hidden, n_head, drop_prob):
    def forward(self, x, s_mask):
        # 1. compute self attention
        _x = x
-       x = self.attention(q=x, k=x, v=x, mask=src_mask)
+       x = self.attention(q=x, k=x, v=x, mask=s_mask)

        # 2. add and norm
        x = self.norm1(x + _x)
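
Same family of fix as in the decoder layer: the encoder's self-attention now uses the s_mask parameter it was given. For reference, a small sketch of what that source pad mask looks like by the time it reaches this layer; the shapes follow the comments in models/model/transformer.py, while the pad index of 1 and the toy tensor are assumptions:

import torch

pad_idx = 1
src = torch.tensor([[4, 7, 9, 1, 1]])   # batch_size=1, len_src=5, padded with index 1
keep = src.ne(pad_idx)                  # True where the token is a real word
s_mask = keep.unsqueeze(1).unsqueeze(2) & keep.unsqueeze(1).unsqueeze(3)
print(s_mask.size())                    # torch.Size([1, 1, 5, 5])
# Positions attending to or from a pad token are False and are later filled with
# a large negative value before the softmax.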

models/embedding/token_embeddings.py (+1 -1)

@@ -19,4 +19,4 @@ class for token embedding that included positional information
        :param vocab_size: size of vocabulary
        :param d_model: dimensions of model
        """
-       super(TokenEmbedding, self).__init__(vocab_size, d_model, padding_idx=0)
+       super(TokenEmbedding, self).__init__(vocab_size, d_model, padding_idx=1)
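
This is the "embedding issue" from the commit message: padding_idx must match the index the vocabulary actually assigns to the pad token. With torchtext's default special-token ordering, '<unk>' is 0 and '<pad>' is 1, which is presumably why the index moves from 0 to 1 here. A small sketch of what padding_idx does (vocabulary size and token ids are made up):

import torch
import torch.nn as nn

emb = nn.Embedding(num_embeddings=10, embedding_dim=4, padding_idx=1)
tokens = torch.tensor([[5, 7, 1, 1]])   # a sequence padded with index 1
out = emb(tokens)
print(out[0, 2])                        # all zeros: the pad row is initialised to the zero vector
# The pad row also receives no gradient updates, so padded positions stay at zero
# throughout training.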

models/layers/scale_dot_product_attention.py (+2 -0)

@@ -30,6 +30,8 @@ def forward(self, q, k, v, mask=None, e=1e-12):
        k_t = k.view(batch_size, head, d_tensor, length)  # transpose
        score = (q @ k_t) / math.sqrt(d_tensor)  # scaled dot product

+       print("score : {}" .format(score.size()))
+       print("mask : {}" .format(mask.size()))
        # 2. apply masking (opt)
        if mask is not None:
            score = score.masked_fill(mask == 0, -e)
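
The two added print statements are debugging output for the pad-mask problem: they show whether the mask arriving here can broadcast against the attention scores. A sketch of the shapes involved, with made-up sizes: score is batch x head x len_q x len_k, while the pad mask built in transformer.py is batch x 1 x len_q x len_k, so masked_fill broadcasts the same mask across every head.

import torch

batch, head, len_q, len_k = 2, 8, 7, 5
score = torch.randn(batch, head, len_q, len_k)
mask = torch.ones(batch, 1, len_q, len_k, dtype=torch.bool)

print("score : {}".format(score.size()))   # torch.Size([2, 8, 7, 5])
print("mask : {}".format(mask.size()))     # torch.Size([2, 1, 7, 5])

masked = score.masked_fill(mask == 0, -1e12)   # the singleton head dimension broadcasts
print(masked.size())                           # torch.Size([2, 8, 7, 5])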

models/model/transformer.py (+23 -17)

@@ -38,31 +38,37 @@ def __init__(self, src_pad_idx, trg_pad_idx, trg_sos_idx, enc_voc_size, dec_voc_
                               device=device)

    def forward(self, src, trg):
-       src_mask = self.make_src_mask(src)
-       trg_mask = self.make_trg_mask(trg)
+       src_mask = self.make_pad_mask(src, src)
+
+       src_trg_mask = self.make_pad_mask(trg, src)
+
+       trg_mask = self.make_pad_mask(trg, trg) * \
+                  self.make_no_peak_mask(trg, trg)
+
        enc_src = self.encoder(src, src_mask)
-       output = self.decoder(trg, enc_src, trg_mask, src_mask)
+       output = self.decoder(trg, enc_src, trg_mask, src_trg_mask)
        return output

-   def make_src_mask(self, src):
-       batch_size, length = src.size()
+   def make_pad_mask(self, q, k):
+       len_q, len_k = q.size(1), k.size(1)

        # batch_size x 1 x 1 x len_k
-       src_k = src.ne(self.src_pad_idx).unsqueeze(1).unsqueeze(2)
+       k = k.ne(self.src_pad_idx).unsqueeze(1).unsqueeze(2)
        # batch_size x 1 x len_q x len_k
-       src_k = src_k.repeat(1, 1, length, 1)
+       k = k.repeat(1, 1, len_q, 1)

        # batch_size x 1 x len_q x 1
-       src_q = src.ne(self.src_pad_idx).unsqueeze(1).unsqueeze(3)
+       q = q.ne(self.src_pad_idx).unsqueeze(1).unsqueeze(3)
        # batch_size x 1 x len_q x len_k
-       src_q = src_q.repeat(1, 1, 1, length)
+       q = q.repeat(1, 1, 1, len_k)
+
+       mask = k & q
+       return mask
+
+   def make_no_peak_mask(self, q, k):
+       len_q, len_k = q.size(1), k.size(1)

-       src_mask = src_k & src_q
-       return src_mask
+       # len_q x len_k
+       mask = torch.tril(torch.ones(len_q, len_k)).type(torch.BoolTensor).to(self.device)

-   def make_trg_mask(self, trg):
-       trg_pad_mask = (trg != self.trg_pad_idx).unsqueeze(1).unsqueeze(3)
-       trg_len = trg.shape[1]
-       trg_sub_mask = torch.tril(torch.ones(trg_len, trg_len)).type(torch.ByteTensor).to(self.device)
-       trg_mask = trg_pad_mask & trg_sub_mask
-       return trg_mask
+       return mask
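
This hunk is the "pad mask issue" itself. The old forward handed src_mask, shaped batch x 1 x len_src x len_src, to the decoder's encoder-decoder attention, where the scores are len_trg x len_src, so the mask could not broadcast whenever source and target lengths differed. The generalised make_pad_mask(q, k) is shaped by both sequences: src_trg_mask = make_pad_mask(trg, src) gives cross-attention the batch x 1 x len_trg x len_src mask it needs, and the target mask combines padding with the lower-triangular no-peak mask. A standalone sketch of the new helpers on toy inputs; the pad index of 1 and the tensors are assumptions, not the repo's config:

import torch

def make_pad_mask(q, k, pad_idx=1):
    len_q, len_k = q.size(1), k.size(1)
    k_mask = k.ne(pad_idx).unsqueeze(1).unsqueeze(2).repeat(1, 1, len_q, 1)  # batch x 1 x len_q x len_k
    q_mask = q.ne(pad_idx).unsqueeze(1).unsqueeze(3).repeat(1, 1, 1, len_k)  # batch x 1 x len_q x len_k
    return k_mask & q_mask

def make_no_peak_mask(q, k):
    # lower-triangular len_q x len_k matrix: position i may only attend to j <= i
    return torch.tril(torch.ones(q.size(1), k.size(1))).bool()

src = torch.tensor([[4, 5, 6, 1, 1]])   # len_src = 5, padded with 1
trg = torch.tensor([[2, 7, 8, 1]])      # len_trg = 4, padded with 1

print(make_pad_mask(src, src).size())   # [1, 1, 5, 5] -> encoder self-attention
print(make_pad_mask(trg, src).size())   # [1, 1, 4, 5] -> encoder-decoder attention
print((make_pad_mask(trg, trg) * make_no_peak_mask(trg, trg)).size())   # [1, 1, 4, 4] -> decoder self-attention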

train.py (+0 -1)

@@ -47,7 +47,6 @@ def initialize_weights(m):
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer=optimizer,
                                                 verbose=True,
                                                 factor=factor,
-                                                min_lr=min_lr,
                                                 patience=patience)

criterion = nn.CrossEntropyLoss(ignore_index=src_pad_idx)
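
min_lr is still a valid ReduceLROnPlateau argument (it defaults to 0), so removing it here just means the scheduler falls back to that default. A minimal sketch of the scheduler as it is built after this change; the model, learning rate, factor and patience are placeholders, not the repo's conf values:

import torch.nn as nn
from torch import optim

model = nn.Linear(8, 8)                                # stand-in for the Transformer
optimizer = optim.Adam(model.parameters(), lr=1e-3)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer=optimizer,
                                                 verbose=True,
                                                 factor=0.9,
                                                 patience=10)

val_loss = 1.234                                       # would come from the validation loop
scheduler.step(val_loss)                               # lr shrinks by `factor` after `patience` epochs without improvement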

util/data_loader.py (+2 -2)

@@ -3,8 +3,8 @@
@when : 2019-10-29
@homepage : https://github.com/gusdnd852
"""
-from torchtext.data import Field, BucketIterator
-from torchtext.datasets.translation import Multi30k
+from torchtext.legacy.data import Field, BucketIterator
+from torchtext.legacy.datasets.translation import Multi30k


class DataLoader:
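
Field, BucketIterator and the translation datasets moved under torchtext.legacy in torchtext 0.9 (and were removed entirely in 0.12), which is what this import change tracks. If the code has to run on both older and newer installs, a version-tolerant import along these lines is one option; this is a sketch, not part of the commit:

try:
    # torchtext >= 0.9: the old declarative API lives under torchtext.legacy
    from torchtext.legacy.data import Field, BucketIterator
    from torchtext.legacy.datasets.translation import Multi30k
except ImportError:
    # torchtext < 0.9 still exposes the original module paths
    from torchtext.data import Field, BucketIterator
    from torchtext.datasets.translation import Multi30k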

util/tokenizer.py (+2 -2)

@@ -9,8 +9,8 @@
class Tokenizer:

    def __init__(self):
-       self.spacy_de = spacy.load('de')
-       self.spacy_en = spacy.load('en')
+       self.spacy_de = spacy.load('de_core_news_sm')
+       self.spacy_en = spacy.load('en_core_web_sm')

    def tokenize_de(self, text):
        """
