decoder.py
import tensorflow as tf

from utils.positional_encoding import positional_encoding
from layers.decoder_layer import DecoderLayer


class Decoder(tf.keras.layers.Layer):
    """Transformer decoder: target-token embedding plus positional encoding,
    followed by a stack of identical decoder layers."""

    def __init__(self,
                 target_vocab_size,
                 embedding_dimension,
                 number_of_attention_heads,
                 num_stacked_decoders,
                 ffn_units,
                 dropout_rate,
                 dropout_training,
                 layernorm_epsilon,
                 name="Decoder"):
        super(Decoder, self).__init__(name=name)
        self.target_vocab_size = target_vocab_size
        self.embedding_dimension = embedding_dimension
        self.number_of_attention_heads = number_of_attention_heads
        self.num_stacked_decoders = num_stacked_decoders
        self.ffn_units = ffn_units
        self.dropout_rate = dropout_rate
        self.dropout_training = dropout_training
        self.layernorm_epsilon = layernorm_epsilon

        # Token embedding for the target vocabulary.
        self.target_embedding_layer = tf.keras.layers.Embedding(self.target_vocab_size,
                                                                self.embedding_dimension)
        # Precomputed sinusoidal positional encodings (maximum positions = target_vocab_size).
        self.decoder_pos_encodings = positional_encoding(self.target_vocab_size,
                                                         self.embedding_dimension)
        # Stack of decoder layers, named decoder_1 .. decoder_N.
        self.stacked_decoder_layers = [DecoderLayer(embedding_dimension=self.embedding_dimension,
                                                    num_attention_heads=self.number_of_attention_heads,
                                                    ffn_units=self.ffn_units,
                                                    dropout_rate=self.dropout_rate,
                                                    dropout_training=self.dropout_training,
                                                    layernorm_epsilon=self.layernorm_epsilon,
                                                    name=f"decoder_{dec_index + 1}")
                                       for dec_index in range(self.num_stacked_decoders)]
        self.dropout = tf.keras.layers.Dropout(self.dropout_rate)

    def call(self, decoder_input, encoder_output, peek_ahead_mask, decoder_padding_mask):
        decoder_target_seq_length = tf.shape(decoder_input)[1]
        attention_weights = {}

        # Embed the target tokens and scale by sqrt(d_model), as in the original Transformer.
        decoder_input_embeddings = self.target_embedding_layer(decoder_input)
        decoder_input_embeddings *= tf.math.sqrt(tf.cast(self.embedding_dimension, tf.float32))
        # Add positional encodings, truncated to the current target sequence length.
        decoder_input_embeddings += self.decoder_pos_encodings[:, :decoder_target_seq_length, :]
        decoder_input_embeddings = self.dropout(decoder_input_embeddings,
                                                training=self.dropout_training)

        # Pass the embeddings through each decoder layer, collecting its
        # self-attention and cross-attention weights along the way.
        for i in range(self.num_stacked_decoders):
            decoder_input_embeddings, self_attention_scores, cross_attention_scores = \
                self.stacked_decoder_layers[i](decoder_input_embeddings,
                                               encoder_output,
                                               peek_ahead_mask,
                                               decoder_padding_mask)
            attention_weights['decoder_layer_self_attention_{}'.format(i + 1)] = self_attention_scores
            attention_weights['decoder_layer_cross_attention_{}'.format(i + 1)] = cross_attention_scores

        return decoder_input_embeddings, attention_weights
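
A minimal usage sketch of how this Decoder could be instantiated and called. The hyperparameter values, batch size, sequence lengths, and mask shapes below are assumptions chosen only for illustration (they are not defined in this file), and the call relies on DecoderLayer accepting broadcastable masks as in the standard TensorFlow Transformer tutorial.

if __name__ == "__main__":
    # Hypothetical hyperparameters, not taken from the repository's config.
    decoder = Decoder(target_vocab_size=8000,
                      embedding_dimension=128,
                      number_of_attention_heads=4,
                      num_stacked_decoders=2,
                      ffn_units=512,
                      dropout_rate=0.1,
                      dropout_training=False,
                      layernorm_epsilon=1e-6)

    batch_size, target_len, input_len = 2, 10, 12
    decoder_input = tf.random.uniform((batch_size, target_len),
                                      minval=0, maxval=8000, dtype=tf.int32)
    encoder_output = tf.random.uniform((batch_size, input_len, 128))

    # Assumed mask conventions: 1 marks positions to block, and shapes are
    # broadcastable to (batch, heads, target_len, key_len).
    peek_ahead_mask = 1 - tf.linalg.band_part(tf.ones((target_len, target_len)), -1, 0)
    decoder_padding_mask = tf.zeros((batch_size, 1, 1, input_len))

    output, attention_weights = decoder(decoder_input, encoder_output,
                                        peek_ahead_mask, decoder_padding_mask)
    print(output.shape)                  # expected: (2, 10, 128)
    print(list(attention_weights.keys()))  # self/cross attention per layer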