Skip to content

Commit 391c9cd

Browse files
author
serotoninpm
committed
Fix tensor transposes: use .transpose() instead of view()-based reshapes in multi-head attention split/concat and in scaled dot-product attention, and set an explicit dim=-1 on nn.Softmax
1 parent 1d2e33f commit 391c9cd

File tree

3 files changed

+5
-4
lines changed

3 files changed

+5
-4
lines changed

.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,4 @@ venv/
44
.data/
55
*.pt
66
__pycache__
7+
result/

models/layers/multi_head_attention.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def split(self, tensor):
4848
batch_size, length, d_model = tensor.size()
4949

5050
d_tensor = d_model // self.n_head
51-
tensor = tensor.view(batch_size, self.n_head, length, d_tensor)
51+
tensor = tensor.view(batch_size, length, self.n_head, d_tensor).transpose(1, 2)
5252
# it is similar with group convolution (split by number of heads)
5353

5454
return tensor
@@ -63,5 +63,5 @@ def concat(self, tensor):
6363
batch_size, head, length, d_tensor = tensor.size()
6464
d_model = head * d_tensor
6565

66-
tensor = tensor.view(batch_size, length, d_model)
66+
tensor = tensor.transpose(1, 2).contiguous().view(batch_size, length, d_model)
6767
return tensor

models/layers/scale_dot_product_attention.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -19,15 +19,15 @@ class ScaleDotProductAttention(nn.Module):
1919

2020
def __init__(self):
2121
super(ScaleDotProductAttention, self).__init__()
22-
self.softmax = nn.Softmax()
22+
self.softmax = nn.Softmax(dim=-1)
2323

2424
def forward(self, q, k, v, mask=None, e=1e-12):
2525
# input is 4 dimension tensor
2626
# [batch_size, head, length, d_tensor]
2727
batch_size, head, length, d_tensor = k.size()
2828

2929
# 1. dot product Query with Key^T to compute similarity
30-
k_t = k.view(batch_size, head, d_tensor, length) # transpose
30+
k_t = k.transpose(2, 3) # transpose
3131
score = (q @ k_t) / math.sqrt(d_tensor) # scaled dot product
3232

3333
# 2. apply masking (opt)

0 commit comments

Comments
 (0)