Commit 069519d

Authored by KumoLiu, pre-commit-ci[bot], and ericspod
Add include_fc and use_combined_linear arguments in the SABlock (#7996)
Fixes #7991
Fixes #7992

### Description

Add `include_fc` and `use_combined_linear` arguments in the `SABlock`.

### Types of changes

- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/` folder.

---------

Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com>
1 parent 0bb05d7 commit 069519d

13 files changed: +426 −137 lines
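For context, a minimal usage sketch of the new options (assuming MONAI at this commit is installed; the remaining constructor arguments follow the existing `SABlock` signature):

```python
import torch
from monai.networks.blocks.selfattention import SABlock

x = torch.randn(2, 16, 128)  # (batch, sequence length, hidden_size)

# Default behaviour: a single combined qkv projection and the final linear layer.
default_block = SABlock(hidden_size=128, num_heads=4)
print(default_block(x).shape)  # expected: torch.Size([2, 16, 128])

# New options: separate q/k/v projections and no final linear layer.
custom_block = SABlock(hidden_size=128, num_heads=4, include_fc=False, use_combined_linear=False)
print(custom_block(x).shape)  # expected: torch.Size([2, 16, 128]) with the default head dimension
```

Both flags only change how the projections are organised internally; with the defaults the output shape should be unchanged, in line with the "non-breaking change" box ticked above.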

monai/networks/blocks/crossattention.py

+11 −22
@@ -59,13 +59,12 @@ def __init__(
             causal (bool, optional): whether to use causal attention.
             sequence_length (int, optional): if causal is True, it is necessary to specify the sequence length.
             rel_pos_embedding (str, optional): Add relative positional embeddings to the attention map. For now only
-                "decomposed" is supported (see https://arxiv.org/abs/2112.01526). 2D and 3D are supported.
+                "decomposed" is supported (see https://arxiv.org/abs/2112.01526). 2D and 3D are supported.
             input_size (tuple(spatial_dim), optional): Input resolution for calculating the relative positional
-                parameter size.
+                parameter size.
             attention_dtype: cast attention operations to this dtype.
-            use_flash_attention: if True, use Pytorch's inbuilt
-                flash attention for a memory efficient attention mechanism (see
-                https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).
+            use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism
+                (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).
         """
 
         super().__init__()
@@ -109,7 +108,7 @@ def __init__(
         self.to_v = nn.Linear(self.context_input_size, inner_size, bias=qkv_bias)
         self.input_rearrange = Rearrange("b h (l d) -> b l h d", l=num_heads)
 
-        self.out_rearrange = Rearrange("b h l d -> b l (h d)")
+        self.out_rearrange = Rearrange("b l h d -> b h (l d)")
         self.drop_output = nn.Dropout(dropout_rate)
         self.drop_weights = nn.Dropout(dropout_rate)
         self.dropout_rate = dropout_rate
@@ -152,31 +151,20 @@ def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None):
         # calculate query, key, values for all heads in batch and move head forward to be the batch dim
         b, t, c = x.size()  # batch size, sequence length, embedding dimensionality (hidden_size)
 
-        q = self.to_q(x)
+        q = self.input_rearrange(self.to_q(x))
         kv = context if context is not None else x
         _, kv_t, _ = kv.size()
-        k = self.to_k(kv)
-        v = self.to_v(kv)
+        k = self.input_rearrange(self.to_k(kv))
+        v = self.input_rearrange(self.to_v(kv))
 
         if self.attention_dtype is not None:
             q = q.to(self.attention_dtype)
             k = k.to(self.attention_dtype)
 
-        q = q.view(b, t, self.num_heads, c // self.num_heads).transpose(1, 2)  # (b, nh, t, hs)
-        k = k.view(b, kv_t, self.num_heads, c // self.num_heads).transpose(1, 2)  # (b, nh, kv_t, hs)
-        v = v.view(b, kv_t, self.num_heads, c // self.num_heads).transpose(1, 2)  # (b, nh, kv_t, hs)
-
         if self.use_flash_attention:
             x = torch.nn.functional.scaled_dot_product_attention(
-                query=q.transpose(1, 2),
-                key=k.transpose(1, 2),
-                value=v.transpose(1, 2),
-                scale=self.scale,
-                dropout_p=self.dropout_rate,
-                is_causal=self.causal,
-            ).transpose(
-                1, 2
-            )  # Back to (b, nh, t, hs)
+                query=q, key=k, value=v, scale=self.scale, dropout_p=self.dropout_rate, is_causal=self.causal
+            )
         else:
             att_mat = torch.einsum("blxd,blyd->blxy", q, k) * self.scale
             # apply relative positional embedding if defined
@@ -195,6 +183,7 @@ def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None):
 
             att_mat = self.drop_weights(att_mat)
             x = torch.einsum("bhxy,bhyd->bhxd", att_mat, v)
+
         x = self.out_rearrange(x)
         x = self.out_proj(x)
         x = self.drop_output(x)
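The forward-pass change above swaps the manual `view`/`transpose` head reshaping for the same einops `Rearrange` patterns used in `SABlock` (in these patterns `h` labels the sequence axis and `l` the number of heads). A standalone sketch of that equivalence, using only `torch` and `einops` with arbitrary example sizes:

```python
import torch
from einops.layers.torch import Rearrange

b, seq, heads, head_dim = 2, 16, 4, 8
x = torch.randn(b, seq, heads * head_dim)  # output of a q/k/v linear projection

# "b h (l d) -> b l h d" splits the heads and moves them in front of the sequence dim,
# matching the old view(...).transpose(1, 2) and the (b, heads, seq, head_dim) layout
# that scaled_dot_product_attention expects.
split_heads = Rearrange("b h (l d) -> b l h d", l=heads)
assert torch.equal(split_heads(x), x.view(b, seq, heads, head_dim).transpose(1, 2))

# "b l h d -> b h (l d)" merges the heads back after attention: (b, seq, heads * head_dim).
merge_heads = Rearrange("b l h d -> b h (l d)")
assert torch.equal(merge_heads(split_heads(x)), x)
```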

monai/networks/blocks/selfattention.py

+40 −20
@@ -11,7 +11,7 @@
 
 from __future__ import annotations
 
-from typing import Optional, Tuple
+from typing import Tuple, Union
 
 import torch
 import torch.nn as nn
@@ -40,9 +40,11 @@ def __init__(
         hidden_input_size: int | None = None,
         causal: bool = False,
         sequence_length: int | None = None,
-        rel_pos_embedding: Optional[str] = None,
-        input_size: Optional[Tuple] = None,
-        attention_dtype: Optional[torch.dtype] = None,
+        rel_pos_embedding: str | None = None,
+        input_size: Tuple | None = None,
+        attention_dtype: torch.dtype | None = None,
+        include_fc: bool = True,
+        use_combined_linear: bool = True,
         use_flash_attention: bool = False,
     ) -> None:
         """
@@ -61,9 +63,10 @@ def __init__(
             input_size (tuple(spatial_dim), optional): Input resolution for calculating the relative
                 positional parameter size.
             attention_dtype: cast attention operations to this dtype.
-            use_flash_attention: if True, use Pytorch's inbuilt
-                flash attention for a memory efficient attention mechanism (see
-                https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).
+            include_fc: whether to include the final linear layer. Default to True.
+            use_combined_linear: whether to use a single linear layer for qkv projection, default to True.
+            use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism
+                (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).
 
         """
 
@@ -105,9 +108,22 @@ def __init__(
         self.hidden_input_size = hidden_input_size if hidden_input_size else hidden_size
         self.out_proj = nn.Linear(self.inner_dim, self.hidden_input_size)
 
-        self.qkv = nn.Linear(self.hidden_input_size, self.inner_dim * 3, bias=qkv_bias)
-        self.input_rearrange = Rearrange("b h (qkv l d) -> qkv b l h d", qkv=3, l=num_heads)
-        self.out_rearrange = Rearrange("b h l d -> b l (h d)")
+        self.qkv: Union[nn.Linear, nn.Identity]
+        self.to_q: Union[nn.Linear, nn.Identity]
+        self.to_k: Union[nn.Linear, nn.Identity]
+        self.to_v: Union[nn.Linear, nn.Identity]
+
+        if use_combined_linear:
+            self.qkv = nn.Linear(self.hidden_input_size, self.inner_dim * 3, bias=qkv_bias)
+            self.to_q = self.to_k = self.to_v = nn.Identity()  # add to enable torchscript
+            self.input_rearrange = Rearrange("b h (qkv l d) -> qkv b l h d", qkv=3, l=num_heads)
+        else:
+            self.to_q = nn.Linear(self.hidden_input_size, self.inner_dim, bias=qkv_bias)
+            self.to_k = nn.Linear(self.hidden_input_size, self.inner_dim, bias=qkv_bias)
+            self.to_v = nn.Linear(self.hidden_input_size, self.inner_dim, bias=qkv_bias)
+            self.qkv = nn.Identity()  # add to enable torchscript
+            self.input_rearrange = Rearrange("b h (l d) -> b l h d", l=num_heads)
+        self.out_rearrange = Rearrange("b l h d -> b h (l d)")
         self.drop_output = nn.Dropout(dropout_rate)
         self.drop_weights = nn.Dropout(dropout_rate)
         self.dropout_rate = dropout_rate
@@ -117,6 +133,8 @@ def __init__(
         self.attention_dtype = attention_dtype
         self.causal = causal
         self.sequence_length = sequence_length
+        self.include_fc = include_fc
+        self.use_combined_linear = use_combined_linear
         self.use_flash_attention = use_flash_attention
 
         if causal and sequence_length is not None:
@@ -144,22 +162,22 @@ def forward(self, x):
         Return:
             torch.Tensor: B x (s_dim_1 * ... * s_dim_n) x C
         """
-        output = self.input_rearrange(self.qkv(x))
-        q, k, v = output[0], output[1], output[2]
+        if self.use_combined_linear:
+            output = self.input_rearrange(self.qkv(x))
+            q, k, v = output[0], output[1], output[2]
+        else:
+            q = self.input_rearrange(self.to_q(x))
+            k = self.input_rearrange(self.to_k(x))
+            v = self.input_rearrange(self.to_v(x))
 
         if self.attention_dtype is not None:
             q = q.to(self.attention_dtype)
             k = k.to(self.attention_dtype)
 
         if self.use_flash_attention:
             x = F.scaled_dot_product_attention(
-                query=q.transpose(1, 2),
-                key=k.transpose(1, 2),
-                value=v.transpose(1, 2),
-                scale=self.scale,
-                dropout_p=self.dropout_rate,
-                is_causal=self.causal,
-            ).transpose(1, 2)
+                query=q, key=k, value=v, scale=self.scale, dropout_p=self.dropout_rate, is_causal=self.causal
+            )
         else:
             att_mat = torch.einsum("blxd,blyd->blxy", q, k) * self.scale
 
@@ -179,7 +197,9 @@ def forward(self, x):
 
             att_mat = self.drop_weights(att_mat)
             x = torch.einsum("bhxy,bhyd->bhxd", att_mat, v)
+
         x = self.out_rearrange(x)
-        x = self.out_proj(x)
+        if self.include_fc:
+            x = self.out_proj(x)
         x = self.drop_output(x)
         return x
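A short sketch of how the two projection modes differ in the layers they create (assuming MONAI at this commit; layer names are taken from the diff above):

```python
from monai.networks.blocks.selfattention import SABlock

combined = SABlock(hidden_size=64, num_heads=8, use_combined_linear=True)
separate = SABlock(hidden_size=64, num_heads=8, use_combined_linear=False, include_fc=False)

# Combined mode keeps the single qkv projection; to_q/to_k/to_v are Identity placeholders.
print([name for name, _ in combined.named_parameters() if name.startswith(("qkv", "to_"))])
# expected: ['qkv.weight']  (no bias, since qkv_bias defaults to False)

# Separate mode creates three projections and leaves qkv as an Identity placeholder.
print([name for name, _ in separate.named_parameters() if name.startswith(("qkv", "to_"))])
# expected: ['to_q.weight', 'to_k.weight', 'to_v.weight']
```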

monai/networks/blocks/spatialattention.py

+10 −1
@@ -32,8 +32,13 @@ class SpatialAttentionBlock(nn.Module):
         spatial_dims: number of spatial dimensions, could be 1, 2, or 3.
         num_channels: number of input channels. Must be divisible by num_head_channels.
         num_head_channels: number of channels per head.
+        norm_num_groups: Number of groups for the group norm layer.
+        norm_eps: Epsilon for the normalization.
         attention_dtype: cast attention operations to this dtype.
-        use_flash_attention: if True, use flash attention for a memory efficient attention mechanism.
+        include_fc: whether to include the final linear layer. Default to True.
+        use_combined_linear: whether to use a single linear layer for qkv projection, default to False.
+        use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism
+            (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).
 
     """
 
@@ -45,6 +50,8 @@ def __init__(
         norm_num_groups: int = 32,
         norm_eps: float = 1e-6,
         attention_dtype: Optional[torch.dtype] = None,
+        include_fc: bool = True,
+        use_combined_linear: bool = False,
         use_flash_attention: bool = False,
     ) -> None:
         super().__init__()
@@ -60,6 +67,8 @@ def __init__(
             num_heads=num_heads,
             qkv_bias=True,
             attention_dtype=attention_dtype,
+            include_fc=include_fc,
+            use_combined_linear=use_combined_linear,
             use_flash_attention=use_flash_attention,
         )
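Here the new flags are simply forwarded to the inner `SABlock`; a hedged usage sketch (assuming MONAI at this commit, with arbitrary example sizes):

```python
import torch
from monai.networks.blocks.spatialattention import SpatialAttentionBlock

attn = SpatialAttentionBlock(
    spatial_dims=2,
    num_channels=64,            # must be divisible by num_head_channels and by norm_num_groups (default 32)
    num_head_channels=16,
    include_fc=True,
    use_combined_linear=False,  # this block defaults to False, per the diff above
)
x = torch.randn(1, 64, 32, 32)  # (B, C, H, W)
print(attn(x).shape)            # expected: torch.Size([1, 64, 32, 32])
```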

monai/networks/blocks/transformerblock.py

+7 −1
@@ -37,6 +37,8 @@ def __init__(
         sequence_length: int | None = None,
         with_cross_attention: bool = False,
         use_flash_attention: bool = False,
+        include_fc: bool = True,
+        use_combined_linear: bool = True,
     ) -> None:
         """
         Args:
@@ -47,7 +49,9 @@ def __init__(
             qkv_bias(bool, optional): apply bias term for the qkv linear layer. Defaults to False.
             save_attn (bool, optional): to make accessible the attention matrix. Defaults to False.
             use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism
-                (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).
+                (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).
+            include_fc: whether to include the final linear layer. Default to True.
+            use_combined_linear: whether to use a single linear layer for qkv projection, default to True.
 
         """
 
@@ -69,6 +73,8 @@ def __init__(
             save_attn=save_attn,
             causal=causal,
             sequence_length=sequence_length,
+            include_fc=include_fc,
+            use_combined_linear=use_combined_linear,
             use_flash_attention=use_flash_attention,
         )
         self.norm2 = nn.LayerNorm(hidden_size)
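Likewise `TransformerBlock` passes the flags through to its self-attention layer; a minimal sketch (assuming MONAI at this commit; positional arguments follow the existing signature):

```python
import torch
from monai.networks.blocks.transformerblock import TransformerBlock

block = TransformerBlock(
    hidden_size=256,
    mlp_dim=1024,
    num_heads=8,
    dropout_rate=0.0,
    include_fc=True,
    use_combined_linear=False,
)
x = torch.randn(2, 196, 256)  # (batch, tokens, hidden_size)
print(block(x).shape)         # expected: torch.Size([2, 196, 256])
```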
