-
Notifications
You must be signed in to change notification settings - Fork 35
/
Copy pathoverride_decoder.py
63 lines (52 loc) · 2.28 KB
/
override_decoder.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
from onmt_modules.decoder_transformer import TransformerDecoder
from onmt_modules.misc import sequence_mask
class OnmtDecoder_1(TransformerDecoder):
    """Transformer decoder whose ``forward`` takes explicit padding info.

    Overrides ``TransformerDecoder.forward``. The target padding mask is
    built from ``kwargs["tgt_lengths"]`` when the whole target sequence is
    decoded at once (``step is None``), and from ``kwargs["tgt_words"]``
    when decoding step by step.

    Original author's note: "without teacher forcing for stop".
    NOTE(review): what that note refers to is not visible in this file —
    confirm with the training/inference code before relying on it.
    """

    def forward(self, tgt, memory_bank, step=None, **kwargs):
        """Decode, possibly stepwise.

        Args:
            tgt: Target input fed to ``self.embeddings``; the resulting
                embedding must be 3-D (len x batch x embedding_dim).
            memory_bank: Encoder output. Dims 0 and 1 are swapped before it
                is handed to the layers (presumably (src_len, batch, dim) on
                input — TODO confirm).
            step: ``None`` to decode the full target sequence in one call;
                otherwise the current decoding step. ``step == 0`` initialises
                the per-layer caches from ``memory_bank``.
            **kwargs:
                memory_lengths: Source lengths used to build the source
                    padding mask (always required).
                tgt_lengths: Target lengths used to build the target padding
                    mask (required when ``step is None``).
                tgt_words: Tensor compared against the embedding padding
                    index to build the target padding mask (required when
                    ``step`` is not ``None``).
                with_align: If true, also return the alignment attention of
                    layer ``self.alignment_layer`` under ``"align"``.

        Returns:
            ``(dec_outs, attns)``: ``dec_outs`` is the layer-normalised
            decoder output with dims 0 and 1 swapped back to the embedding
            layout; ``attns`` is a dict holding the last layer's attention
            under ``"std"`` (and ``"copy"`` when ``self._copy`` is set) plus
            ``"align"`` when ``with_align`` is requested.

        Raises:
            KeyError: If a required keyword argument is missing.
            ValueError: If the embedding is not 3-D.
        """
        if step == 0:
            # First incremental step: build the per-layer caches.
            self._init_cache(memory_bank)

        # Full-sequence decoding needs target lengths; stepwise decoding
        # needs the tensor used for the padding comparison further down.
        if step is None:
            tgt_lens = kwargs["tgt_lengths"]
        else:
            tgt_words = kwargs["tgt_words"]

        emb = self.embeddings(tgt, step=step)
        # Explicit check rather than `assert`, which is stripped under -O.
        if emb.dim() != 3:  # len x batch x embedding_dim
            raise ValueError(
                "expected a 3-D embedding (len x batch x embedding_dim), "
                "got {} dims".format(emb.dim()))

        output = emb.transpose(0, 1).contiguous()
        src_memory_bank = memory_bank.transpose(0, 1).contiguous()

        pad_idx = self.embeddings.word_padding_idx

        # NOTE(review): sequence_mask presumably marks valid positions True,
        # so `~` makes padding True — confirm against onmt_modules.misc.
        src_lens = kwargs["memory_lengths"]
        src_max_len = self.state["src"].shape[0]
        src_pad_mask = ~sequence_mask(src_lens, src_max_len).unsqueeze(1)

        if step is None:
            tgt_max_len = tgt_lens.max()
            tgt_pad_mask = ~sequence_mask(tgt_lens, tgt_max_len).unsqueeze(1)
        else:
            # Compare the tensor directly: `.data` bypasses autograd tracking
            # and is discouraged; a boolean mask needs no gradient anyway.
            tgt_pad_mask = tgt_words.eq(pad_idx).unsqueeze(1)

        with_align = kwargs.pop('with_align', False)
        attn_aligns = []

        for i, layer in enumerate(self.transformer_layers):
            # Layer caches are only used for stepwise decoding.
            layer_cache = self.state["cache"]["layer_{}".format(i)] \
                if step is not None else None
            output, attn, attn_align = layer(
                output,
                src_memory_bank,
                src_pad_mask,
                tgt_pad_mask,
                layer_cache=layer_cache,
                step=step,
                with_align=with_align)
            if attn_align is not None:
                attn_aligns.append(attn_align)

        output = self.layer_norm(output)
        # Swap dims 0 and 1 back to the layout the embedding had.
        dec_outs = output.transpose(0, 1).contiguous()
        # `attn` is overwritten on every iteration, so this is the last
        # layer's attention only.
        attn = attn.transpose(0, 1).contiguous()

        attns = {"std": attn}
        if self._copy:
            attns["copy"] = attn
        if with_align:
            attns["align"] = attn_aligns[self.alignment_layer]  # `(B, Q, K)`
            # Alternative (needs torch): average over all layers with
            # torch.stack(attn_aligns, 0).mean(0)
        # TODO change the way attns is returned dict => list or tuple (onnx)
        return dec_outs, attns