Models feature/normalization layers #47

Open
wants to merge 19 commits into base: develop
Changes from all commits
Commits (19)
cf0cc85
Refactor to instantiate normalization layer.
jakob-schloer Nov 22, 2024
2079846
Fix dependencies for development
jakob-schloer Nov 22, 2024
8bc7d79
can pass arbitrary kernels via config
cathalobrien Dec 4, 2024
5505e97
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 5, 2024
ae41452
Merge remote-tracking branch 'origin/develop' into feature/normalizat…
jakob-schloer Dec 11, 2024
cb79197
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 11, 2024
ccafbb2
Set default behavior for layer_kernels.
jakob-schloer Dec 12, 2024
bd46af7
Merge branch 'feature/normalization-layers' of github.com:ecmwf/anemo…
jakob-schloer Dec 12, 2024
b9c1ff9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 12, 2024
2621880
Merge remote-tracking branch 'origin/develop' into feature/normalizat…
jakob-schloer Dec 17, 2024
4e35033
Add flexible layer kernels to GNN and GraphTransformer
jakob-schloer Dec 18, 2024
84dfdfc
Add type annotation.
jakob-schloer Dec 18, 2024
ad78c59
Add tests for layer kernels and adapt tests
jakob-schloer Dec 20, 2024
671e47c
Migrate models_feature/normalization-layers.
jakob-schloer Dec 20, 2024
f7d70d9
Add layer kernel to config.
jakob-schloer Dec 20, 2024
022e34d
Add conditional LN. Implementation originally written by @ssmmnn11
jakob-schloer Dec 20, 2024
cac4ad4
Including @ssmmnn11 comments.
jakob-schloer Jan 8, 2025
e4da1c0
Change docstring style
jakob-schloer Jan 8, 2025
26ca381
Remove comments in config.
jakob-schloer Jan 8, 2025
1 change: 1 addition & 0 deletions models/pyproject.toml
@@ -42,6 +42,7 @@ classifiers = [

dynamic = [ "version" ]
dependencies = [
# Fix certain dependencies during development
"anemoi-utils>=0.1.9",
"einops>=0.6.1",
"hydra-core>=1.3",
7 changes: 5 additions & 2 deletions models/src/anemoi/models/layers/attention.py
@@ -27,6 +27,7 @@

from anemoi.models.distributed.transformer import shard_heads
from anemoi.models.distributed.transformer import shard_sequence
from anemoi.utils.config import DotDict

LOGGER = logging.getLogger(__name__)

@@ -38,6 +39,7 @@ def __init__(
self,
num_heads: int,
embed_dim: int,
layer_kernels: DotDict,
bias: bool = False,
is_causal: bool = False,
window_size: Optional[int] = None,
@@ -56,13 +58,14 @@ def __init__(
self.dropout_p = dropout_p
self.is_causal = is_causal

self.lin_qkv = nn.Linear(embed_dim, 3 * embed_dim, bias=bias)
linear = layer_kernels["Linear"]
self.lin_qkv = linear(embed_dim, 3 * embed_dim, bias=bias)
self.attention = attn_func

if not _FLASH_ATTENTION_AVAILABLE:
LOGGER.warning("Flash attention not available, falling back to pytorch scaled_dot_product_attention")

self.projection = nn.Linear(embed_dim, embed_dim, bias=True)
self.projection = linear(embed_dim, embed_dim, bias=True)

def forward(
self, x: Tensor, shapes: list, batch_size: int, model_comm_group: Optional[ProcessGroup] = None
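For context, the attention block above now looks up its Linear implementation from layer_kernels instead of hard-coding nn.Linear. Below is a minimal sketch of constructing it directly, assuming a plain dict of torch classes stands in for the config-driven DotDict; the concrete dimensions and the standalone usage are illustrative, not part of this PR.

from torch import nn

from anemoi.models.layers.attention import MultiHeadSelfAttention

# Stand-in for the config-driven mapping; real runs would populate this
# from config/models/<model>.yaml.
layer_kernels = {
    "Linear": nn.Linear,        # used for lin_qkv and the output projection
    "LayerNorm": nn.LayerNorm,  # used by the blocks in block.py below
}

attention = MultiHeadSelfAttention(
    num_heads=8,                # illustrative values
    embed_dim=256,
    layer_kernels=layer_kernels,
    bias=False,
    is_causal=False,
    dropout_p=0.0,
)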
82 changes: 57 additions & 25 deletions models/src/anemoi/models/layers/block.py
@@ -32,6 +32,7 @@
from anemoi.models.layers.conv import GraphConv
from anemoi.models.layers.conv import GraphTransformerConv
from anemoi.models.layers.mlp import MLP
from anemoi.utils.config import DotDict

LOGGER = logging.getLogger(__name__)

@@ -68,6 +69,7 @@ def __init__(
num_heads: int,
activation: str,
window_size: int,
layer_kernels: DotDict,
dropout_p: float = 0.0,
):
super().__init__()
@@ -78,7 +80,8 @@ def __init__(
LOGGER.error("Activation function %s not supported", activation)
raise RuntimeError from ae

self.layer_norm1 = nn.LayerNorm(num_channels)
self.layer_norm_attention = layer_kernels["LayerNorm"](normalized_shape=num_channels)
self.layer_norm_mlp = layer_kernels["LayerNorm"](normalized_shape=num_channels)

self.attention = MultiHeadSelfAttention(
num_heads=num_heads,
@@ -87,21 +90,21 @@ def __init__(
bias=False,
is_causal=False,
dropout_p=dropout_p,
layer_kernels=layer_kernels,
)

self.mlp = nn.Sequential(
nn.Linear(num_channels, hidden_dim),
layer_kernels["Linear"](num_channels, hidden_dim),
act_func(),
nn.Linear(hidden_dim, num_channels),
layer_kernels["Linear"](hidden_dim, num_channels),
)
self.layer_norm2 = nn.LayerNorm(num_channels)

def forward(
self, x: Tensor, shapes: list, batch_size: int, model_comm_group: Optional[ProcessGroup] = None
) -> Tensor:
# Need to be out of place for gradient propagation
x = x + self.attention(self.layer_norm1(x), shapes, batch_size, model_comm_group=model_comm_group)
x = x + self.mlp(self.layer_norm2(x))
x = x + self.attention(self.layer_norm_attention(x), shapes, batch_size, model_comm_group=model_comm_group)
x = x + self.mlp(self.layer_norm_mlp(x))
return x


@@ -112,6 +115,7 @@ def __init__(
self,
in_channels: int,
out_channels: int,
layer_kernels: DotDict,
mlp_extra_layers: int = 0,
activation: str = "SiLU",
update_src_nodes: bool = True,
@@ -126,6 +130,9 @@ def __init__(
Number of input channels.
out_channels : int
Number of output channels.
layer_kernels : DotDict
A dict of layer implementations e.g. layer_kernels['Linear'] = "torch.nn.Linear"
Defined in config/models/<model>.yaml
mlp_extra_layers : int, optional
Extra layers in MLP, by default 0
activation : str, optional
@@ -144,13 +151,15 @@ def __init__(
2 * in_channels,
out_channels,
out_channels,
layer_kernels,
n_extra_layers=mlp_extra_layers,
activation=activation,
)

self.conv = GraphConv(
in_channels=in_channels,
out_channels=out_channels,
layer_kernels=layer_kernels,
mlp_extra_layers=mlp_extra_layers,
activation=activation,
)
@@ -173,6 +182,7 @@ def __ini__(
self,
in_channels: int,
out_channels: int,
layer_kernels: DotDict,
mlp_extra_layers: int = 0,
activation: str = "SiLU",
update_src_nodes: bool = True,
@@ -183,6 +193,7 @@
self,
in_channels=in_channels,
out_channels=out_channels,
layer_kernels=layer_kernels,
mlp_extra_layers=mlp_extra_layers,
activation=activation,
update_src_nodes=update_src_nodes,
@@ -229,6 +240,7 @@ def __ini__(
self,
in_channels: int,
out_channels: int,
layer_kernels: DotDict,
mlp_extra_layers: int = 0,
activation: str = "SiLU",
update_src_nodes: bool = True,
@@ -239,6 +251,7 @@
self,
in_channels=in_channels,
out_channels=out_channels,
layer_kernels=layer_kernels,
mlp_extra_layers=mlp_extra_layers,
activation=activation,
update_src_nodes=update_src_nodes,
@@ -295,6 +308,7 @@ def __init__(
hidden_dim: int,
out_channels: int,
edge_dim: int,
layer_kernels: DotDict,
num_heads: int = 16,
bias: bool = True,
activation: str = "GELU",
@@ -312,6 +326,9 @@ def __init__(
Number of output channels.
edge_dim : int,
Edge dimension
layer_kernels : DotDict
A dict of layer implementations e.g. layer_kernels['Linear'] = "torch.nn.Linear"
Defined in config/models/<model>.yaml
num_heads : int,
Number of heads
bias : bool, by default True,
@@ -330,37 +347,40 @@ def __init__(

self.num_chunks = num_chunks

self.lin_key = nn.Linear(in_channels, num_heads * self.out_channels_conv)
self.lin_query = nn.Linear(in_channels, num_heads * self.out_channels_conv)
self.lin_value = nn.Linear(in_channels, num_heads * self.out_channels_conv)
self.lin_self = nn.Linear(in_channels, num_heads * self.out_channels_conv, bias=bias)
self.lin_edge = nn.Linear(edge_dim, num_heads * self.out_channels_conv) # , bias=False)
linear = layer_kernels["Linear"]
layerNorm = layer_kernels["LayerNorm"]
self.lin_key = linear(in_channels, num_heads * self.out_channels_conv)
self.lin_query = linear(in_channels, num_heads * self.out_channels_conv)
self.lin_value = linear(in_channels, num_heads * self.out_channels_conv)
self.lin_self = linear(in_channels, num_heads * self.out_channels_conv, bias=bias)
self.lin_edge = linear(edge_dim, num_heads * self.out_channels_conv) # , bias=False)

self.conv = GraphTransformerConv(out_channels=self.out_channels_conv)

self.projection = nn.Linear(out_channels, out_channels)
self.projection = linear(out_channels, out_channels)

try:
act_func = getattr(nn, activation)
except AttributeError as ae:
LOGGER.error("Activation function %s not supported", activation)
raise RuntimeError from ae

self.layer_norm_attention = layerNorm(normalized_shape=in_channels)
self.layer_norm_mlp = layerNorm(normalized_shape=out_channels)

self.node_dst_mlp = nn.Sequential(
nn.LayerNorm(out_channels),
nn.Linear(out_channels, hidden_dim),
self.layer_norm_mlp,
linear(out_channels, hidden_dim),
act_func(),
nn.Linear(hidden_dim, out_channels),
linear(hidden_dim, out_channels),
)

self.layer_norm1 = nn.LayerNorm(in_channels)

if self.update_src_nodes:
self.node_src_mlp = nn.Sequential(
nn.LayerNorm(out_channels),
nn.Linear(out_channels, hidden_dim),
self.layer_norm_mlp,
linear(out_channels, hidden_dim),
act_func(),
nn.Linear(hidden_dim, out_channels),
linear(hidden_dim, out_channels),
)

def shard_qkve_heads(
@@ -435,6 +455,7 @@ def __init__(
hidden_dim: int,
out_channels: int,
edge_dim: int,
layer_kernels: DotDict,
num_heads: int = 16,
bias: bool = True,
activation: str = "GELU",
@@ -452,6 +473,9 @@ def __init__(
Number of output channels.
edge_dim : int,
Edge dimension
layer_kernels : DotDict
A dict of layer implementations e.g. layer_kernels['Linear'] = "torch.nn.Linear"
Defined in config/models/<model>.yaml
num_heads : int,
Number of heads
bias : bool, by default True,
@@ -466,6 +490,7 @@
hidden_dim=hidden_dim,
out_channels=out_channels,
edge_dim=edge_dim,
layer_kernels=layer_kernels,
num_heads=num_heads,
bias=bias,
activation=activation,
@@ -474,7 +499,8 @@
**kwargs,
)

self.layer_norm2 = nn.LayerNorm(in_channels)
self.layer_norm_attention_src = self.layer_norm_attention
self.layer_norm_attention_dest = layer_kernels["LayerNorm"](normalized_shape=in_channels)

def forward(
self,
@@ -489,9 +515,10 @@ def forward(
x_skip = x

x = (
self.layer_norm1(x[0]),
self.layer_norm2(x[1]),
) # Why does this use layer_norm2? And only is a mapper thing?
self.layer_norm_attention_src(x[0]),
self.layer_norm_attention_dest(x[1]),
)

x_r = self.lin_self(x[1])
query = self.lin_query(x[1])
key = self.lin_key(x[0])
@@ -559,6 +586,7 @@ def __init__(
hidden_dim: int,
out_channels: int,
edge_dim: int,
layer_kernels: DotDict,
num_heads: int = 16,
bias: bool = True,
activation: str = "GELU",
@@ -576,6 +604,9 @@ def __init__(
Number of output channels.
edge_dim : int,
Edge dimension
layer_kernels : DotDict
A dict of layer implementations e.g. layer_kernels['Linear'] = "torch.nn.Linear"
Defined in config/models/<model>.yaml
num_heads : int,
Number of heads
bias : bool, by default True,
@@ -591,6 +622,7 @@
hidden_dim=hidden_dim,
out_channels=out_channels,
edge_dim=edge_dim,
layer_kernels=layer_kernels,
num_heads=num_heads,
bias=bias,
activation=activation,
@@ -611,7 +643,7 @@
):
x_skip = x

x = self.layer_norm1(x)
x = self.layer_norm_attention(x)
x_r = self.lin_self(x)
query = self.lin_query(x)
key = self.lin_key(x)
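The docstrings in this file describe layer_kernels entries as dotted-path strings set in config/models/<model>.yaml (e.g. layer_kernels['Linear'] = "torch.nn.Linear"), while the blocks call them as classes (layer_kernels["LayerNorm"](normalized_shape=...)). The sketch below shows one way such strings could be resolved into callables before model construction; it is an illustration only, and the resolve_kernel helper is hypothetical rather than part of this PR (the repository's Hydra config machinery may handle this step directly).

import importlib

def resolve_kernel(dotted_path: str):
    """Resolve a dotted path such as 'torch.nn.Linear' to the class it names."""
    module_name, _, attr = dotted_path.rpartition(".")
    return getattr(importlib.import_module(module_name), attr)

# Hypothetical config excerpt mirroring the docstring example.
config_layer_kernels = {
    "Linear": "torch.nn.Linear",
    "LayerNorm": "torch.nn.LayerNorm",
}

layer_kernels = {name: resolve_kernel(path) for name, path in config_layer_kernels.items()}

# layer_kernels["LayerNorm"](normalized_shape=256) now builds a standard LayerNorm;
# swapping in a custom or conditional LayerNorm only requires changing the string
# in the config, which is the flexibility this PR adds.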