From cf0cc85f6b35220467199750f826fe84240c69a5 Mon Sep 17 00:00:00 2001 From: Jakob Schloer Date: Fri, 22 Nov 2024 14:11:23 +0000 Subject: [PATCH 01/15] Refactor to instantiate normalization layer. --- src/anemoi/models/layers/block.py | 13 ++++++---- src/anemoi/models/layers/chunk.py | 10 ++++++-- src/anemoi/models/layers/mlp.py | 2 +- src/anemoi/models/layers/normalization.py | 31 +++++++++++++++++++++++ src/anemoi/models/layers/processor.py | 4 ++- src/anemoi/models/layers/utils.py | 15 ----------- 6 files changed, 51 insertions(+), 24 deletions(-) create mode 100644 src/anemoi/models/layers/normalization.py diff --git a/src/anemoi/models/layers/block.py b/src/anemoi/models/layers/block.py index 60446d6c..81ee3204 100644 --- a/src/anemoi/models/layers/block.py +++ b/src/anemoi/models/layers/block.py @@ -69,6 +69,7 @@ def __init__( activation: str, window_size: int, dropout_p: float = 0.0, + layer_norm: Optional[dict] = None ): super().__init__() @@ -78,7 +79,8 @@ def __init__( LOGGER.error("Activation function %s not supported", activation) raise RuntimeError from ae - self.layer_norm1 = nn.LayerNorm(num_channels) + self.layer_norm1 = layer_norm() if layer_norm else nn.LayerNorm(num_channels) + self.layer_norm2 = layer_norm() if layer_norm else nn.LayerNorm(num_channels) self.attention = MultiHeadSelfAttention( num_heads=num_heads, @@ -94,14 +96,15 @@ def __init__( act_func(), nn.Linear(hidden_dim, num_channels), ) - self.layer_norm2 = nn.LayerNorm(num_channels) def forward( - self, x: Tensor, shapes: list, batch_size: int, model_comm_group: Optional[ProcessGroup] = None + self, x: Tensor, shapes: list, batch_size: int, + model_comm_group: Optional[ProcessGroup] = None, + **kwargs ) -> Tensor: # Need to be out of place for gradient propagation - x = x + self.attention(self.layer_norm1(x), shapes, batch_size, model_comm_group=model_comm_group) - x = x + self.mlp(self.layer_norm2(x)) + x = x + self.attention(self.layer_norm1(x, **kwargs), shapes, batch_size, model_comm_group=model_comm_group) + x = x + self.mlp(self.layer_norm2(x, **kwargs)) return x diff --git a/src/anemoi/models/layers/chunk.py b/src/anemoi/models/layers/chunk.py index 5c4fae38..44dfc76a 100644 --- a/src/anemoi/models/layers/chunk.py +++ b/src/anemoi/models/layers/chunk.py @@ -75,6 +75,7 @@ def __init__( mlp_hidden_ratio: int = 4, activation: str = "GELU", dropout_p: float = 0.0, + layer_norm: Optional[dict] = None, ) -> None: """Initialize TransformerProcessor. 
@@ -103,13 +104,18 @@ def __init__( activation=activation, window_size=window_size, dropout_p=dropout_p, + layer_norm=layer_norm, ) def forward( - self, x: Tensor, shapes: list, batch_size: int, model_comm_group: Optional[ProcessGroup] = None + self, x: Tensor, shapes: list, batch_size: int, + model_comm_group: Optional[ProcessGroup] = None, + **kwargs, ) -> Tensor: for i in range(self.num_layers): - x = self.blocks[i](x, shapes, batch_size, model_comm_group=model_comm_group) + x = self.blocks[i](x, shapes, batch_size, + model_comm_group=model_comm_group, + **kwargs) return (x,) # return tuple for consistency with other processors diff --git a/src/anemoi/models/layers/mlp.py b/src/anemoi/models/layers/mlp.py index 4a1e7957..1d939704 100644 --- a/src/anemoi/models/layers/mlp.py +++ b/src/anemoi/models/layers/mlp.py @@ -13,7 +13,7 @@ import torch from torch import nn -from anemoi.models.layers.utils import AutocastLayerNorm +from anemoi.models.layers.normalization import AutocastLayerNorm from anemoi.models.layers.utils import CheckpointWrapper LOGGER = logging.getLogger(__name__) diff --git a/src/anemoi/models/layers/normalization.py b/src/anemoi/models/layers/normalization.py new file mode 100644 index 00000000..551d90c8 --- /dev/null +++ b/src/anemoi/models/layers/normalization.py @@ -0,0 +1,31 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +from __future__ import annotations + +from abc import ABC +from abc import abstractmethod + +import torch +from torch import nn + + +class AutocastLayerNorm(nn.LayerNorm): + """LayerNorm that casts the output back to the input type.""" + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + def forward(self, x: Tensor) -> Tensor: + """Forward with explicit autocast back to the input type. + + This casts the output to (b)float16 (instead of float32) when we run in mixed + precision. + """ + return super().forward(x).type_as(x) diff --git a/src/anemoi/models/layers/processor.py b/src/anemoi/models/layers/processor.py index 4fd32311..0afdf927 100644 --- a/src/anemoi/models/layers/processor.py +++ b/src/anemoi/models/layers/processor.py @@ -97,6 +97,7 @@ def __init__( num_heads: int = 16, mlp_hidden_ratio: int = 4, dropout_p: float = 0.1, + layer_norm: Optional[dict] = None, **kwargs, ) -> None: """Initialize TransformerProcessor. 
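[Editor's note on this patch: despite the Optional[dict] annotation, the layer_norm argument threaded through these hunks is consumed as a zero-argument factory in TransformerProcessorBlock (layer_norm() if layer_norm else nn.LayerNorm(num_channels)). A minimal sketch of how a caller could build such a factory — illustrative only, not part of the patch; the channel count is an assumed value:

    from functools import partial
    from torch import nn

    num_channels = 1024  # assumed model width
    # Bind all LayerNorm arguments up front so the block can call layer_norm()
    # with no arguments, exactly as the hunk above does.
    layer_norm = partial(nn.LayerNorm, normalized_shape=num_channels, eps=1e-5)
    norm = layer_norm()  # equivalent to nn.LayerNorm(num_channels, eps=1e-5)
]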
@@ -138,6 +139,7 @@ def __init__( window_size=window_size, activation=activation, dropout_p=dropout_p, + layer_norm=layer_norm, ) self.offload_layers(cpu_offload) @@ -157,7 +159,7 @@ def forward( model_comm_group.size() == 1 or batch_size == 1 ), "Only batch size of 1 is supported when model is sharded accross GPUs" - (x,) = self.run_layers((x,), shape_nodes, batch_size, model_comm_group) + (x,) = self.run_layers((x,), shape_nodes, batch_size, model_comm_group, **kwargs) return x diff --git a/src/anemoi/models/layers/utils.py b/src/anemoi/models/layers/utils.py index e243874a..f35c2b8b 100644 --- a/src/anemoi/models/layers/utils.py +++ b/src/anemoi/models/layers/utils.py @@ -22,18 +22,3 @@ def __init__(self, module: nn.Module) -> None: def forward(self, *args, **kwargs): return checkpoint(self.module, *args, **kwargs, use_reentrant=False) - - -class AutocastLayerNorm(nn.LayerNorm): - """LayerNorm that casts the output back to the input type.""" - - def __init__(self, *args, **kwargs) -> None: - super().__init__(*args, **kwargs) - - def forward(self, x: Tensor) -> Tensor: - """Forward with explicit autocast back to the input type. - - This casts the output to (b)float16 (instead of float32) when we run in mixed - precision. - """ - return super().forward(x).type_as(x) From 207984679d015852ca45b51f2aa2dc67a4aadf7b Mon Sep 17 00:00:00 2001 From: Jakob Schloer Date: Fri, 22 Nov 2024 15:12:33 +0000 Subject: [PATCH 02/15] Fix dependencies for development --- pyproject.toml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 6d473472..45d56709 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,6 +42,8 @@ classifiers = [ dynamic = [ "version" ] dependencies = [ + # Fix certain dependencies during development + "anemoi-training @ git+https://github.com/ecmwf/anemoi-training.git@25abf5e143a29d5931ccb4ac42a5f83c5cd26851", "anemoi-utils>=0.1.9", "einops>=0.6.1", "hydra-core>=1.3", From 8bc7d79fc8d6e2a119974c77d9c8124a86e59077 Mon Sep 17 00:00:00 2001 From: Cathal OBrien Date: Wed, 4 Dec 2024 10:05:32 +0000 Subject: [PATCH 03/15] can pass arbitrary kernels via config --- src/anemoi/models/layers/attention.py | 8 ++- src/anemoi/models/layers/block.py | 71 +++++++++++++------ src/anemoi/models/layers/chunk.py | 21 +++++- src/anemoi/models/layers/mapper.py | 21 +++++- src/anemoi/models/layers/mlp.py | 17 +++-- src/anemoi/models/layers/processor.py | 9 ++- .../models/encoder_processor_decoder.py | 30 ++++++++ 7 files changed, 143 insertions(+), 34 deletions(-) diff --git a/src/anemoi/models/layers/attention.py b/src/anemoi/models/layers/attention.py index d7f54920..dc859b47 100644 --- a/src/anemoi/models/layers/attention.py +++ b/src/anemoi/models/layers/attention.py @@ -28,6 +28,8 @@ from anemoi.models.distributed.transformer import shard_heads from anemoi.models.distributed.transformer import shard_sequence +from anemoi.utils.config import DotDict + LOGGER = logging.getLogger(__name__) @@ -38,6 +40,7 @@ def __init__( self, num_heads: int, embed_dim: int, + layer_kernels: DotDict, bias: bool = False, is_causal: bool = False, window_size: Optional[int] = None, @@ -56,13 +59,14 @@ def __init__( self.dropout_p = dropout_p self.is_causal = is_causal - self.lin_qkv = nn.Linear(embed_dim, 3 * embed_dim, bias=bias) + linear=layer_kernels["Linear"] + self.lin_qkv = linear(embed_dim, 3 * embed_dim, bias=bias) self.attention = attn_func if not _FLASH_ATTENTION_AVAILABLE: LOGGER.warning("Flash attention not available, falling back to pytorch scaled_dot_product_attention") - 
self.projection = nn.Linear(embed_dim, embed_dim, bias=True) + self.projection = linear(embed_dim, embed_dim, bias=True) def forward( self, x: Tensor, shapes: list, batch_size: int, model_comm_group: Optional[ProcessGroup] = None diff --git a/src/anemoi/models/layers/block.py b/src/anemoi/models/layers/block.py index 81ee3204..8296f1b4 100644 --- a/src/anemoi/models/layers/block.py +++ b/src/anemoi/models/layers/block.py @@ -32,6 +32,7 @@ from anemoi.models.layers.conv import GraphConv from anemoi.models.layers.conv import GraphTransformerConv from anemoi.models.layers.mlp import MLP +from anemoi.utils.config import DotDict LOGGER = logging.getLogger(__name__) @@ -68,8 +69,8 @@ def __init__( num_heads: int, activation: str, window_size: int, + layer_kernels: DotDict, dropout_p: float = 0.0, - layer_norm: Optional[dict] = None ): super().__init__() @@ -79,8 +80,8 @@ def __init__( LOGGER.error("Activation function %s not supported", activation) raise RuntimeError from ae - self.layer_norm1 = layer_norm() if layer_norm else nn.LayerNorm(num_channels) - self.layer_norm2 = layer_norm() if layer_norm else nn.LayerNorm(num_channels) + self.layer_norm1 = layer_kernels["LayerNorm"](num_channels) + self.layer_norm2 = layer_kernels["LayerNorm"](num_channels) self.attention = MultiHeadSelfAttention( num_heads=num_heads, @@ -89,12 +90,13 @@ def __init__( bias=False, is_causal=False, dropout_p=dropout_p, + layer_kernels=layer_kernels, ) self.mlp = nn.Sequential( - nn.Linear(num_channels, hidden_dim), + layer_kernels["Linear"](num_channels, hidden_dim), act_func(), - nn.Linear(hidden_dim, num_channels), + layer_kernels["Linear"](hidden_dim, num_channels), ) def forward( @@ -103,8 +105,8 @@ def forward( **kwargs ) -> Tensor: # Need to be out of place for gradient propagation - x = x + self.attention(self.layer_norm1(x, **kwargs), shapes, batch_size, model_comm_group=model_comm_group) - x = x + self.mlp(self.layer_norm2(x, **kwargs)) + x = x + self.attention(self.layer_norm1(x), shapes, batch_size, model_comm_group=model_comm_group) + x = x + self.mlp(self.layer_norm2(x)) return x @@ -115,6 +117,7 @@ def __init__( self, in_channels: int, out_channels: int, + layer_kernels: DotDict, mlp_extra_layers: int = 0, activation: str = "SiLU", update_src_nodes: bool = True, @@ -129,6 +132,9 @@ def __init__( Number of input channels. out_channels : int Number of output channels. + layer_kernels : DotDict + A dict of layer implementations e.g. 
layer_kernels['Linear'] = "torch.nn.Linear" + Defined in config/models/.yaml mlp_extra_layers : int, optional Extra layers in MLP, by default 0 activation : str, optional @@ -147,6 +153,7 @@ def __init__( 2 * in_channels, out_channels, out_channels, + layer_kernels, n_extra_layers=mlp_extra_layers, activation=activation, ) @@ -176,6 +183,7 @@ def __ini__( self, in_channels: int, out_channels: int, + layer_kernels: DotDict, mlp_extra_layers: int = 0, activation: str = "SiLU", update_src_nodes: bool = True, @@ -190,6 +198,7 @@ def __ini__( activation=activation, update_src_nodes=update_src_nodes, num_chunks=num_chunks, + layer_kernels=layer_kernels, **kwargs, ) @@ -232,6 +241,7 @@ def __ini__( self, in_channels: int, out_channels: int, + layer_kernels: DotDict, mlp_extra_layers: int = 0, activation: str = "SiLU", update_src_nodes: bool = True, @@ -246,6 +256,7 @@ def __ini__( activation=activation, update_src_nodes=update_src_nodes, num_chunks=num_chunks, + layer_kernels=layer_kernels, **kwargs, ) @@ -298,6 +309,7 @@ def __init__( hidden_dim: int, out_channels: int, edge_dim: int, + layer_kernels: DotDict, num_heads: int = 16, bias: bool = True, activation: str = "GELU", @@ -315,6 +327,9 @@ def __init__( Number of output channels. edge_dim : int, Edge dimension + layer_kernels : DotDict + A dict of layer implementations e.g. layer_kernels['Linear'] = "torch.nn.Linear" + Defined in config/models/.yaml num_heads : int, Number of heads bias : bool, by default True, @@ -333,15 +348,17 @@ def __init__( self.num_chunks = num_chunks - self.lin_key = nn.Linear(in_channels, num_heads * self.out_channels_conv) - self.lin_query = nn.Linear(in_channels, num_heads * self.out_channels_conv) - self.lin_value = nn.Linear(in_channels, num_heads * self.out_channels_conv) - self.lin_self = nn.Linear(in_channels, num_heads * self.out_channels_conv, bias=bias) - self.lin_edge = nn.Linear(edge_dim, num_heads * self.out_channels_conv) # , bias=False) + linear=layer_kernels['Linear'] + layerNorm=layer_kernels['LayerNorm'] + self.lin_key = linear(in_channels, num_heads * self.out_channels_conv) + self.lin_query = linear(in_channels, num_heads * self.out_channels_conv) + self.lin_value = linear(in_channels, num_heads * self.out_channels_conv) + self.lin_self = linear(in_channels, num_heads * self.out_channels_conv, bias=bias) + self.lin_edge = linear(edge_dim, num_heads * self.out_channels_conv) # , bias=False) self.conv = GraphTransformerConv(out_channels=self.out_channels_conv) - self.projection = nn.Linear(out_channels, out_channels) + self.projection = linear(out_channels, out_channels) try: act_func = getattr(nn, activation) @@ -350,20 +367,20 @@ def __init__( raise RuntimeError from ae self.node_dst_mlp = nn.Sequential( - nn.LayerNorm(out_channels), - nn.Linear(out_channels, hidden_dim), + layerNorm(out_channels), + linear(out_channels, hidden_dim), act_func(), - nn.Linear(hidden_dim, out_channels), + linear(hidden_dim, out_channels), ) - self.layer_norm1 = nn.LayerNorm(in_channels) + self.layer_norm1 = layerNorm(in_channels) if self.update_src_nodes: self.node_src_mlp = nn.Sequential( - nn.LayerNorm(out_channels), - nn.Linear(out_channels, hidden_dim), + layerNorm(out_channels), + linear(out_channels, hidden_dim), act_func(), - nn.Linear(hidden_dim, out_channels), + linear(hidden_dim, out_channels), ) def shard_qkve_heads( @@ -438,6 +455,7 @@ def __init__( hidden_dim: int, out_channels: int, edge_dim: int, + layer_kernels: DotDict, num_heads: int = 16, bias: bool = True, activation: str = "GELU", @@ -455,6 
+473,9 @@ def __init__( Number of output channels. edge_dim : int, Edge dimension + layer_kernels : DotDict + A dict of layer implementations e.g. layer_kernels['Linear'] = "torch.nn.Linear" + Defined in config/models/.yaml num_heads : int, Number of heads bias : bool, by default True, @@ -469,6 +490,7 @@ def __init__( hidden_dim=hidden_dim, out_channels=out_channels, edge_dim=edge_dim, + layer_kernels=layer_kernels, num_heads=num_heads, bias=bias, activation=activation, @@ -477,7 +499,7 @@ def __init__( **kwargs, ) - self.layer_norm2 = nn.LayerNorm(in_channels) + self.layer_norm2 = layer_kernels["LayerNorm"](in_channels) def forward( self, @@ -564,6 +586,7 @@ def __init__( hidden_dim: int, out_channels: int, edge_dim: int, + layer_kernels: DotDict, num_heads: int = 16, bias: bool = True, activation: str = "GELU", @@ -581,6 +604,9 @@ def __init__( Number of output channels. edge_dim : int, Edge dimension + layer_kernels : DotDict + A dict of layer implementations e.g. layer_kernels['Linear'] = "torch.nn.Linear" + Defined in config/models/.yaml num_heads : int, Number of heads bias : bool, by default True, @@ -596,11 +622,12 @@ def __init__( hidden_dim=hidden_dim, out_channels=out_channels, edge_dim=edge_dim, + layer_kernels=layer_kernels, num_heads=num_heads, bias=bias, activation=activation, num_chunks=num_chunks, - update_src_nodes=update_src_nodes, + update_src_nodes=update_src_nodes **kwargs, ) diff --git a/src/anemoi/models/layers/chunk.py b/src/anemoi/models/layers/chunk.py index 44dfc76a..621101db 100644 --- a/src/anemoi/models/layers/chunk.py +++ b/src/anemoi/models/layers/chunk.py @@ -24,6 +24,7 @@ from anemoi.models.layers.block import GraphTransformerProcessorBlock from anemoi.models.layers.block import TransformerProcessorBlock from anemoi.models.layers.mlp import MLP +from anemoi.utils.config import DotDict LOGGER = logging.getLogger(__name__) @@ -37,6 +38,7 @@ def __init__( num_layers: int, *args, activation: str = "GELU", + layer_norm: Optional[dict] = None, **kwargs, ) -> None: """Initialize BaseProcessorChunk.""" @@ -71,11 +73,11 @@ def __init__( num_channels: int, num_layers: int, window_size: int, + layer_kernels: DotDict, num_heads: int = 16, mlp_hidden_ratio: int = 4, activation: str = "GELU", dropout_p: float = 0.0, - layer_norm: Optional[dict] = None, ) -> None: """Initialize TransformerProcessor. @@ -85,6 +87,11 @@ def __init__( Number of channels num_layers : int Number of layers + window_size: int, + 1/2 size of shifted window for attention computation + layer_kernels : DotDict + A dict of layer implementations e.g. layer_kernels['Linear'] = "torch.nn.Linear" + Defined in config/models/.yaml num_heads: int Number of heads to use, default 16 mlp_hidden_ratio: int @@ -104,7 +111,7 @@ def __init__( activation=activation, window_size=window_size, dropout_p=dropout_p, - layer_norm=layer_norm, + layer_kernels=layer_kernels ) def forward( @@ -127,6 +134,7 @@ def __init__( self, num_channels: int, num_layers: int, + layer_kernels: DotDict, mlp_extra_layers: int = 0, activation: str = "SiLU", edge_dim: Optional[int] = None, @@ -139,6 +147,9 @@ def __init__( Channels of the message passing blocks. num_layers : int Number of message passing blocks. + layer_kernels : DotDict + A dict of layer implementations e.g. 
layer_kernels['Linear'] = "torch.nn.Linear" + Defined in config/models/.yaml mlp_extra_layers : int, optional Extra num_layers in MLP, by default 0 activation : str, optional @@ -166,6 +177,7 @@ def __init__( num_channels, mlp_extra_layers=mlp_extra_layers, activation=activation, + layer_kernels=layer_kernels, ) def forward( @@ -194,6 +206,7 @@ def __init__( self, num_channels: int, num_layers: int, + layer_kernels: DotDict, num_heads: int = 16, mlp_hidden_ratio: int = 4, activation: str = "GELU", @@ -207,6 +220,9 @@ def __init__( Number of channels. num_layers : int Number of layers. + layer_kernels : DotDict + A dict of layer implementations e.g. layer_kernels['Linear'] = "torch.nn.Linear" + Defined in config/models/.yaml num_heads: int Number of heads to use, default 16 mlp_hidden_ratio: int @@ -226,6 +242,7 @@ def __init__( num_heads=num_heads, edge_dim=edge_dim, activation=activation, + layer_kernels=layer_kernels, ) def forward( diff --git a/src/anemoi/models/layers/mapper.py b/src/anemoi/models/layers/mapper.py index 1ae45031..3ebcdff8 100644 --- a/src/anemoi/models/layers/mapper.py +++ b/src/anemoi/models/layers/mapper.py @@ -31,6 +31,7 @@ from anemoi.models.layers.block import GraphTransformerMapperBlock from anemoi.models.layers.graph import TrainableTensor from anemoi.models.layers.mlp import MLP +from anemoi.utils.config import DotDict LOGGER = logging.getLogger(__name__) @@ -190,6 +191,7 @@ def __init__( sub_graph_edge_attributes: Optional[list[str]] = None, src_grid_size: int = 0, dst_grid_size: int = 0, + layer_kernels: DotDict = None, ) -> None: """Initialize GraphTransformerBaseMapper. @@ -213,6 +215,9 @@ def __init__( Whether to offload processing to CPU, by default False out_channels_dst : Optional[int], optional Output channels of the destination node, by default None + layer_kernels : DotDict, optional + A dict of layer implementations e.g. layer_kernels['Linear'] = "torch.nn.Linear" + Defined in config/models/.yaml """ super().__init__( in_channels_src, @@ -222,7 +227,11 @@ def __init__( num_chunks=num_chunks, cpu_offload=cpu_offload, activation=activation, + layer_kernels=layer_kernels, ) + + #Linear = layer_kernels.get("Linear", torch.nn.Linear) + Linear = layer_kernels["Linear"] self._register_edges(sub_graph, sub_graph_edge_attributes, src_grid_size, dst_grid_size, trainable_size) @@ -236,11 +245,12 @@ def __init__( edge_dim=self.edge_dim, activation=activation, num_chunks=num_chunks, + layer_kernels=layer_kernels ) self.offload_layers(cpu_offload) - self.emb_nodes_dst = nn.Linear(self.in_channels_dst, self.hidden_dim) + self.emb_nodes_dst = Linear(self.in_channels_dst, self.hidden_dim) def forward( self, @@ -291,6 +301,7 @@ def __init__( sub_graph_edge_attributes: Optional[list[str]] = None, src_grid_size: int = 0, dst_grid_size: int = 0, + layer_kernels: DotDict = None, ) -> None: """Initialize GraphTransformerForwardMapper. @@ -330,9 +341,10 @@ def __init__( sub_graph_edge_attributes=sub_graph_edge_attributes, src_grid_size=src_grid_size, dst_grid_size=dst_grid_size, + layer_kernels=layer_kernels, ) - self.emb_nodes_src = nn.Linear(self.in_channels_src, self.hidden_dim) + self.emb_nodes_src = layer_kernels["Linear"](self.in_channels_src, self.hidden_dim) def forward( self, @@ -364,6 +376,7 @@ def __init__( sub_graph_edge_attributes: Optional[list[str]] = None, src_grid_size: int = 0, dst_grid_size: int = 0, + layer_kernels: DotDict = None, ) -> None: """Initialize GraphTransformerBackwardMapper. 
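[Editor's note on this patch: the layer_kernels argument added across these mappers is a mapping from a layer name ("Linear", "LayerNorm") to a hydra-style partial target. A minimal sketch of the expected shape, taken from the torch.nn defaults introduced later in this same patch; the concrete entries would normally come from config.model.layer_kernels:

    from hydra.utils import instantiate

    layer_kernels_cfg = {
        "Linear": {"_target_": "torch.nn.Linear", "_partial_": True},
        "LayerNorm": {"_target_": "torch.nn.LayerNorm", "_partial_": True},
    }
    # With _partial_: True, instantiate() returns a callable that behaves like
    # the target class, i.e. a drop-in replacement for nn.Linear / nn.LayerNorm.
    Linear = instantiate(layer_kernels_cfg["Linear"])
    layer = Linear(128, 256, bias=False)
]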
@@ -387,6 +400,9 @@ def __init__( Whether to offload processing to CPU, by default False out_channels_dst : Optional[int], optional Output channels of the destination node, by default None + layer_kernels : DotDict + A dict of layer implementations e.g. layer_kernels['Linear'] = "torch.nn.Linear" + Defined in config/models/.yaml """ super().__init__( in_channels_src, @@ -403,6 +419,7 @@ def __init__( sub_graph_edge_attributes=sub_graph_edge_attributes, src_grid_size=src_grid_size, dst_grid_size=dst_grid_size, + layer_kernels=layer_kernels, ) self.node_data_extractor = nn.Sequential( diff --git a/src/anemoi/models/layers/mlp.py b/src/anemoi/models/layers/mlp.py index 1d939704..e771d6e6 100644 --- a/src/anemoi/models/layers/mlp.py +++ b/src/anemoi/models/layers/mlp.py @@ -15,6 +15,7 @@ from anemoi.models.layers.normalization import AutocastLayerNorm from anemoi.models.layers.utils import CheckpointWrapper +from anemoi.utils.config import DotDict LOGGER = logging.getLogger(__name__) @@ -27,6 +28,7 @@ def __init__( in_features: int, hidden_dim: int, out_features: int, + layer_kernels: DotDict, n_extra_layers: int = 0, activation: str = "SiLU", final_activation: bool = False, @@ -43,6 +45,9 @@ def __init__( Hidden dimensions out_features : int Number of output features + layer_kernels : DotDict + A dict of layer implementations e.g. layer_kernels['Linear'] = "torch.nn.Linear" + Defined in config/models/.yaml n_extra_layers : int, optional Number of extra layers in MLP, by default 0 activation : str, optional @@ -65,23 +70,27 @@ def __init__( If activation function is not supported """ super().__init__() + + Linear = layer_kernels["Linear"] + LayerNorm = layer_kernels["LayerNorm"] + try: act_func = getattr(nn, activation) except AttributeError as ae: LOGGER.error("Activation function %s not supported", activation) raise RuntimeError from ae - mlp1 = nn.Sequential(nn.Linear(in_features, hidden_dim), act_func()) + mlp1 = nn.Sequential(Linear(in_features, hidden_dim), act_func()) for _ in range(n_extra_layers + 1): - mlp1.append(nn.Linear(hidden_dim, hidden_dim)) + mlp1.append(Linear(hidden_dim, hidden_dim)) mlp1.append(act_func()) - mlp1.append(nn.Linear(hidden_dim, out_features)) + mlp1.append(Linear(hidden_dim, out_features)) if final_activation: mlp1.append(act_func()) if layer_norm: - mlp1.append(AutocastLayerNorm(out_features)) + mlp1.append(LayerNorm(out_features).as_type(out_features)) self.model = CheckpointWrapper(mlp1) if checkpoints else mlp1 diff --git a/src/anemoi/models/layers/processor.py b/src/anemoi/models/layers/processor.py index 0afdf927..a069ab3f 100644 --- a/src/anemoi/models/layers/processor.py +++ b/src/anemoi/models/layers/processor.py @@ -27,6 +27,7 @@ from anemoi.models.layers.chunk import TransformerProcessorChunk from anemoi.models.layers.graph import TrainableTensor from anemoi.models.layers.mapper import GraphEdgeMixin +from anemoi.utils.config import DotDict class BaseProcessor(nn.Module, ABC): @@ -88,6 +89,7 @@ class TransformerProcessor(BaseProcessor): def __init__( self, num_layers: int, + layer_kernels: DotDict, *args, window_size: Optional[int] = None, num_channels: int = 128, @@ -97,7 +99,6 @@ def __init__( num_heads: int = 16, mlp_hidden_ratio: int = 4, dropout_p: float = 0.1, - layer_norm: Optional[dict] = None, **kwargs, ) -> None: """Initialize TransformerProcessor. @@ -106,6 +107,9 @@ def __init__( ---------- num_layers : int Number of num_layers + layer_kernels : DotDict + A dict of layer implementations e.g. 
layer_kernels['Linear'] = "torch.nn.Linear" + Defined in config/models/.yaml window_size: int, 1/2 size of shifted window for attention computation num_channels : int @@ -128,6 +132,7 @@ def __init__( cpu_offload=cpu_offload, num_heads=num_heads, mlp_hidden_ratio=mlp_hidden_ratio, + #layer_kernels=layer_kernels, ) self.build_layers( @@ -139,7 +144,7 @@ def __init__( window_size=window_size, activation=activation, dropout_p=dropout_p, - layer_norm=layer_norm, + layer_kernels=layer_kernels, ) self.offload_layers(cpu_offload) diff --git a/src/anemoi/models/models/encoder_processor_decoder.py b/src/anemoi/models/models/encoder_processor_decoder.py index c67c8c03..16dc9bd9 100644 --- a/src/anemoi/models/models/encoder_processor_decoder.py +++ b/src/anemoi/models/models/encoder_processor_decoder.py @@ -20,6 +20,7 @@ from torch.distributed.distributed_c10d import ProcessGroup from torch.utils.checkpoint import checkpoint from torch_geometric.data import HeteroData +from hydra.errors import InstantiationException from anemoi.models.distributed.shapes import get_shape_shards from anemoi.models.layers.graph import NamedNodesAttributes @@ -64,6 +65,9 @@ def __init__( self.node_attributes = NamedNodesAttributes(model_config.model.trainable_parameters.hidden, self._graph_data) input_dim = self.multi_step * self.num_input_channels + self.node_attributes.attr_ndims[self._graph_name_data] + + # read config.model.layer_kernels to get the implementation for certain layers + self._load_layer_kernels(model_config) # Encoder data -> hidden self.encoder = instantiate( @@ -74,6 +78,7 @@ def __init__( sub_graph=self._graph_data[(self._graph_name_data, "to", self._graph_name_hidden)], src_grid_size=self.node_attributes.num_nodes[self._graph_name_data], dst_grid_size=self.node_attributes.num_nodes[self._graph_name_hidden], + layer_kernels=self.layer_kernels, ) # Processor hidden -> hidden @@ -83,6 +88,7 @@ def __init__( sub_graph=self._graph_data[(self._graph_name_hidden, "to", self._graph_name_hidden)], src_grid_size=self.node_attributes.num_nodes[self._graph_name_hidden], dst_grid_size=self.node_attributes.num_nodes[self._graph_name_hidden], + layer_kernels=self.layer_kernels, ) # Decoder hidden -> data @@ -95,6 +101,7 @@ def __init__( sub_graph=self._graph_data[(self._graph_name_hidden, "to", self._graph_name_data)], src_grid_size=self.node_attributes.num_nodes[self._graph_name_hidden], dst_grid_size=self.node_attributes.num_nodes[self._graph_name_data], + layer_kernels=self.layer_kernels, ) # Instantiation of model output bounding functions (e.g., to ensure outputs like TP are positive definite) @@ -231,3 +238,26 @@ def forward(self, x: Tensor, model_comm_group: Optional[ProcessGroup] = None) -> x_out = bounding(x_out) return x_out + + def _load_layer_kernels(self, config: DotDict) -> None: + + # If self.layer_kernels entry is missing from the config, use torch.nn by default + default_kernels=DotDict() + default_kernels["Linear"] = DotDict({"_target_": "torch.nn.Linear", "_partial_": True}) + default_kernels["LayerNorm"] = DotDict({"_target_": "torch.nn.LayerNorm", "_partial_": True}) + + #self.layer_kernels = config.get("model.layer_kernels", default_kernels) #Always uses default kernels... 
+ self.layer_kernels= config.model.layer_kernels + + # Loop through all kernels in the layer_kernels config entry and try import them + for kernel in self.layer_kernels: + kernel_entry = self.layer_kernels[kernel] + try: + instantiate(kernel_entry) + except InstantiationException: + LOGGER.info( + f"{kernel_entry['_target_']} not found! check your config.model.layer_kernel.{kernel} entry. Maybe your desired kernel is not installed or the import string is incorrect?" + ) + raise InstantiationException + else: + LOGGER.info(f"{kernel} kernel: {kernel_entry}") \ No newline at end of file From 5505e975a914bb18483e95d6e9adad4770167d1f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 5 Dec 2024 12:48:22 +0000 Subject: [PATCH 04/15] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/anemoi/models/layers/attention.py | 6 +++--- src/anemoi/models/layers/block.py | 13 +++++-------- src/anemoi/models/layers/chunk.py | 13 +++++++------ src/anemoi/models/layers/mapper.py | 8 ++++---- src/anemoi/models/layers/mlp.py | 5 ++--- src/anemoi/models/layers/normalization.py | 3 --- src/anemoi/models/layers/processor.py | 4 ++-- src/anemoi/models/layers/utils.py | 1 - .../models/models/encoder_processor_decoder.py | 16 ++++++++-------- 9 files changed, 31 insertions(+), 38 deletions(-) diff --git a/src/anemoi/models/layers/attention.py b/src/anemoi/models/layers/attention.py index dc859b47..64cd7b69 100644 --- a/src/anemoi/models/layers/attention.py +++ b/src/anemoi/models/layers/attention.py @@ -25,11 +25,11 @@ else: _FLASH_ATTENTION_AVAILABLE = True +from anemoi.utils.config import DotDict + from anemoi.models.distributed.transformer import shard_heads from anemoi.models.distributed.transformer import shard_sequence -from anemoi.utils.config import DotDict - LOGGER = logging.getLogger(__name__) @@ -59,7 +59,7 @@ def __init__( self.dropout_p = dropout_p self.is_causal = is_causal - linear=layer_kernels["Linear"] + linear = layer_kernels["Linear"] self.lin_qkv = linear(embed_dim, 3 * embed_dim, bias=bias) self.attention = attn_func diff --git a/src/anemoi/models/layers/block.py b/src/anemoi/models/layers/block.py index 8296f1b4..edfd2f46 100644 --- a/src/anemoi/models/layers/block.py +++ b/src/anemoi/models/layers/block.py @@ -16,6 +16,7 @@ import einops import torch +from anemoi.utils.config import DotDict from torch import Tensor from torch import nn from torch.distributed.distributed_c10d import ProcessGroup @@ -32,7 +33,6 @@ from anemoi.models.layers.conv import GraphConv from anemoi.models.layers.conv import GraphTransformerConv from anemoi.models.layers.mlp import MLP -from anemoi.utils.config import DotDict LOGGER = logging.getLogger(__name__) @@ -100,9 +100,7 @@ def __init__( ) def forward( - self, x: Tensor, shapes: list, batch_size: int, - model_comm_group: Optional[ProcessGroup] = None, - **kwargs + self, x: Tensor, shapes: list, batch_size: int, model_comm_group: Optional[ProcessGroup] = None, **kwargs ) -> Tensor: # Need to be out of place for gradient propagation x = x + self.attention(self.layer_norm1(x), shapes, batch_size, model_comm_group=model_comm_group) @@ -348,8 +346,8 @@ def __init__( self.num_chunks = num_chunks - linear=layer_kernels['Linear'] - layerNorm=layer_kernels['LayerNorm'] + linear = layer_kernels["Linear"] + layerNorm = layer_kernels["LayerNorm"] self.lin_key = linear(in_channels, num_heads * self.out_channels_conv) self.lin_query = linear(in_channels, 
num_heads * self.out_channels_conv) self.lin_value = linear(in_channels, num_heads * self.out_channels_conv) @@ -627,8 +625,7 @@ def __init__( bias=bias, activation=activation, num_chunks=num_chunks, - update_src_nodes=update_src_nodes - **kwargs, + update_src_nodes=update_src_nodes**kwargs, ) def forward( diff --git a/src/anemoi/models/layers/chunk.py b/src/anemoi/models/layers/chunk.py index 621101db..2c2d0761 100644 --- a/src/anemoi/models/layers/chunk.py +++ b/src/anemoi/models/layers/chunk.py @@ -13,6 +13,7 @@ from abc import abstractmethod from typing import Optional +from anemoi.utils.config import DotDict from torch import Tensor from torch import nn from torch.distributed.distributed_c10d import ProcessGroup @@ -24,7 +25,6 @@ from anemoi.models.layers.block import GraphTransformerProcessorBlock from anemoi.models.layers.block import TransformerProcessorBlock from anemoi.models.layers.mlp import MLP -from anemoi.utils.config import DotDict LOGGER = logging.getLogger(__name__) @@ -111,18 +111,19 @@ def __init__( activation=activation, window_size=window_size, dropout_p=dropout_p, - layer_kernels=layer_kernels + layer_kernels=layer_kernels, ) def forward( - self, x: Tensor, shapes: list, batch_size: int, + self, + x: Tensor, + shapes: list, + batch_size: int, model_comm_group: Optional[ProcessGroup] = None, **kwargs, ) -> Tensor: for i in range(self.num_layers): - x = self.blocks[i](x, shapes, batch_size, - model_comm_group=model_comm_group, - **kwargs) + x = self.blocks[i](x, shapes, batch_size, model_comm_group=model_comm_group, **kwargs) return (x,) # return tuple for consistency with other processors diff --git a/src/anemoi/models/layers/mapper.py b/src/anemoi/models/layers/mapper.py index 3ebcdff8..c4537647 100644 --- a/src/anemoi/models/layers/mapper.py +++ b/src/anemoi/models/layers/mapper.py @@ -14,6 +14,7 @@ import numpy as np import torch +from anemoi.utils.config import DotDict from torch import Tensor from torch import nn from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import offload_wrapper @@ -31,7 +32,6 @@ from anemoi.models.layers.block import GraphTransformerMapperBlock from anemoi.models.layers.graph import TrainableTensor from anemoi.models.layers.mlp import MLP -from anemoi.utils.config import DotDict LOGGER = logging.getLogger(__name__) @@ -229,8 +229,8 @@ def __init__( activation=activation, layer_kernels=layer_kernels, ) - - #Linear = layer_kernels.get("Linear", torch.nn.Linear) + + # Linear = layer_kernels.get("Linear", torch.nn.Linear) Linear = layer_kernels["Linear"] self._register_edges(sub_graph, sub_graph_edge_attributes, src_grid_size, dst_grid_size, trainable_size) @@ -245,7 +245,7 @@ def __init__( edge_dim=self.edge_dim, activation=activation, num_chunks=num_chunks, - layer_kernels=layer_kernels + layer_kernels=layer_kernels, ) self.offload_layers(cpu_offload) diff --git a/src/anemoi/models/layers/mlp.py b/src/anemoi/models/layers/mlp.py index e771d6e6..5230b00c 100644 --- a/src/anemoi/models/layers/mlp.py +++ b/src/anemoi/models/layers/mlp.py @@ -11,11 +11,10 @@ import logging import torch +from anemoi.utils.config import DotDict from torch import nn -from anemoi.models.layers.normalization import AutocastLayerNorm from anemoi.models.layers.utils import CheckpointWrapper -from anemoi.utils.config import DotDict LOGGER = logging.getLogger(__name__) @@ -73,7 +72,7 @@ def __init__( Linear = layer_kernels["Linear"] LayerNorm = layer_kernels["LayerNorm"] - + try: act_func = getattr(nn, activation) except AttributeError as ae: diff --git 
a/src/anemoi/models/layers/normalization.py b/src/anemoi/models/layers/normalization.py index 551d90c8..be400f95 100644 --- a/src/anemoi/models/layers/normalization.py +++ b/src/anemoi/models/layers/normalization.py @@ -9,10 +9,7 @@ from __future__ import annotations -from abc import ABC -from abc import abstractmethod -import torch from torch import nn diff --git a/src/anemoi/models/layers/processor.py b/src/anemoi/models/layers/processor.py index a069ab3f..e892883b 100644 --- a/src/anemoi/models/layers/processor.py +++ b/src/anemoi/models/layers/processor.py @@ -11,6 +11,7 @@ from abc import ABC from typing import Optional +from anemoi.utils.config import DotDict from torch import Tensor from torch import nn from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import offload_wrapper @@ -27,7 +28,6 @@ from anemoi.models.layers.chunk import TransformerProcessorChunk from anemoi.models.layers.graph import TrainableTensor from anemoi.models.layers.mapper import GraphEdgeMixin -from anemoi.utils.config import DotDict class BaseProcessor(nn.Module, ABC): @@ -132,7 +132,7 @@ def __init__( cpu_offload=cpu_offload, num_heads=num_heads, mlp_hidden_ratio=mlp_hidden_ratio, - #layer_kernels=layer_kernels, + # layer_kernels=layer_kernels, ) self.build_layers( diff --git a/src/anemoi/models/layers/utils.py b/src/anemoi/models/layers/utils.py index f35c2b8b..6bec46aa 100644 --- a/src/anemoi/models/layers/utils.py +++ b/src/anemoi/models/layers/utils.py @@ -8,7 +8,6 @@ # nor does it submit to any jurisdiction. -from torch import Tensor from torch import nn from torch.utils.checkpoint import checkpoint diff --git a/src/anemoi/models/models/encoder_processor_decoder.py b/src/anemoi/models/models/encoder_processor_decoder.py index 16dc9bd9..fd5eb139 100644 --- a/src/anemoi/models/models/encoder_processor_decoder.py +++ b/src/anemoi/models/models/encoder_processor_decoder.py @@ -14,13 +14,13 @@ import einops import torch from anemoi.utils.config import DotDict +from hydra.errors import InstantiationException from hydra.utils import instantiate from torch import Tensor from torch import nn from torch.distributed.distributed_c10d import ProcessGroup from torch.utils.checkpoint import checkpoint from torch_geometric.data import HeteroData -from hydra.errors import InstantiationException from anemoi.models.distributed.shapes import get_shape_shards from anemoi.models.layers.graph import NamedNodesAttributes @@ -65,7 +65,7 @@ def __init__( self.node_attributes = NamedNodesAttributes(model_config.model.trainable_parameters.hidden, self._graph_data) input_dim = self.multi_step * self.num_input_channels + self.node_attributes.attr_ndims[self._graph_name_data] - + # read config.model.layer_kernels to get the implementation for certain layers self._load_layer_kernels(model_config) @@ -242,13 +242,13 @@ def forward(self, x: Tensor, model_comm_group: Optional[ProcessGroup] = None) -> def _load_layer_kernels(self, config: DotDict) -> None: # If self.layer_kernels entry is missing from the config, use torch.nn by default - default_kernels=DotDict() + default_kernels = DotDict() default_kernels["Linear"] = DotDict({"_target_": "torch.nn.Linear", "_partial_": True}) default_kernels["LayerNorm"] = DotDict({"_target_": "torch.nn.LayerNorm", "_partial_": True}) - - #self.layer_kernels = config.get("model.layer_kernels", default_kernels) #Always uses default kernels... 
- self.layer_kernels= config.model.layer_kernels - + + # self.layer_kernels = config.get("model.layer_kernels", default_kernels) #Always uses default kernels... + self.layer_kernels = config.model.layer_kernels + # Loop through all kernels in the layer_kernels config entry and try import them for kernel in self.layer_kernels: kernel_entry = self.layer_kernels[kernel] @@ -260,4 +260,4 @@ def _load_layer_kernels(self, config: DotDict) -> None: ) raise InstantiationException else: - LOGGER.info(f"{kernel} kernel: {kernel_entry}") \ No newline at end of file + LOGGER.info(f"{kernel} kernel: {kernel_entry}") From cb791973f8484000e78e0c88caee1bc5dbbd77ac Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 11 Dec 2024 14:04:05 +0000 Subject: [PATCH 05/15] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/anemoi/models/layers/normalization.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/anemoi/models/layers/normalization.py b/src/anemoi/models/layers/normalization.py index be400f95..7665c9ea 100644 --- a/src/anemoi/models/layers/normalization.py +++ b/src/anemoi/models/layers/normalization.py @@ -9,7 +9,6 @@ from __future__ import annotations - from torch import nn From ccafbb2d4e09ba57d6c22a7f68addf1e3d70aa1d Mon Sep 17 00:00:00 2001 From: Jakob Schloer Date: Thu, 12 Dec 2024 09:53:15 +0000 Subject: [PATCH 06/15] Set default behavior for layer_kernels. --- pyproject.toml | 1 - src/anemoi/models/layers/block.py | 28 ++++++++++--------- src/anemoi/models/layers/chunk.py | 3 +- src/anemoi/models/layers/processor.py | 2 +- .../models/encoder_processor_decoder.py | 15 +++++----- 5 files changed, 25 insertions(+), 24 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 45d56709..aacce2f9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,7 +43,6 @@ classifiers = [ dynamic = [ "version" ] dependencies = [ # Fix certain dependencies during development - "anemoi-training @ git+https://github.com/ecmwf/anemoi-training.git@25abf5e143a29d5931ccb4ac42a5f83c5cd26851", "anemoi-utils>=0.1.9", "einops>=0.6.1", "hydra-core>=1.3", diff --git a/src/anemoi/models/layers/block.py b/src/anemoi/models/layers/block.py index edfd2f46..cf04d815 100644 --- a/src/anemoi/models/layers/block.py +++ b/src/anemoi/models/layers/block.py @@ -13,6 +13,7 @@ from abc import ABC from abc import abstractmethod from typing import Optional +from hydra.utils import instantiate import einops import torch @@ -80,8 +81,8 @@ def __init__( LOGGER.error("Activation function %s not supported", activation) raise RuntimeError from ae - self.layer_norm1 = layer_kernels["LayerNorm"](num_channels) - self.layer_norm2 = layer_kernels["LayerNorm"](num_channels) + self.layer_norm_attention = layer_kernels["LayerNorm"](normalized_shape=num_channels) + self.layer_norm_mlp = layer_kernels["LayerNorm"](normalized_shape=num_channels) self.attention = MultiHeadSelfAttention( num_heads=num_heads, @@ -100,11 +101,11 @@ def __init__( ) def forward( - self, x: Tensor, shapes: list, batch_size: int, model_comm_group: Optional[ProcessGroup] = None, **kwargs + self, x: Tensor, shapes: list, batch_size: int, model_comm_group: Optional[ProcessGroup] = None ) -> Tensor: # Need to be out of place for gradient propagation - x = x + self.attention(self.layer_norm1(x), shapes, batch_size, model_comm_group=model_comm_group) - x = x + self.mlp(self.layer_norm2(x)) + x = x + self.attention(self.layer_norm_attention(x), shapes, 
batch_size, model_comm_group=model_comm_group) + x = x + self.mlp(self.layer_norm_mlp(x)) return x @@ -356,6 +357,7 @@ def __init__( self.conv = GraphTransformerConv(out_channels=self.out_channels_conv) + # Why does the GraphTransformer not have a layer_norm_mlp like the Transformer? self.projection = linear(out_channels, out_channels) try: @@ -365,17 +367,17 @@ def __init__( raise RuntimeError from ae self.node_dst_mlp = nn.Sequential( - layerNorm(out_channels), + layerNorm(normalized_shape=out_channels), linear(out_channels, hidden_dim), act_func(), linear(hidden_dim, out_channels), ) - self.layer_norm1 = layerNorm(in_channels) + self.layer_norm_attention = layerNorm(normalized_shape=in_channels) if self.update_src_nodes: self.node_src_mlp = nn.Sequential( - layerNorm(out_channels), + layerNorm(normlaized_shape=out_channels), linear(out_channels, hidden_dim), act_func(), linear(hidden_dim, out_channels), @@ -497,7 +499,7 @@ def __init__( **kwargs, ) - self.layer_norm2 = layer_kernels["LayerNorm"](in_channels) + self.layer_norm_attention_2 = layer_kernels["LayerNorm"](normalized_shape=in_channels) def forward( self, @@ -512,9 +514,9 @@ def forward( x_skip = x x = ( - self.layer_norm1(x[0]), - self.layer_norm2(x[1]), - ) # Why does this use layer_norm2? And only is a mapper thing? + self.layer_norm_attention(x[0]), + self.layer_norm_attention_2(x[1]), + ) # Why does this use layer_norm_attention_2? And only is a mapper thing? x_r = self.lin_self(x[1]) query = self.lin_query(x[1]) key = self.lin_key(x[0]) @@ -640,7 +642,7 @@ def forward( ): x_skip = x - x = self.layer_norm1(x) + x = self.layer_norm_attention(x) x_r = self.lin_self(x) query = self.lin_query(x) key = self.lin_key(x) diff --git a/src/anemoi/models/layers/chunk.py b/src/anemoi/models/layers/chunk.py index 2c2d0761..179b63b3 100644 --- a/src/anemoi/models/layers/chunk.py +++ b/src/anemoi/models/layers/chunk.py @@ -120,10 +120,9 @@ def forward( shapes: list, batch_size: int, model_comm_group: Optional[ProcessGroup] = None, - **kwargs, ) -> Tensor: for i in range(self.num_layers): - x = self.blocks[i](x, shapes, batch_size, model_comm_group=model_comm_group, **kwargs) + x = self.blocks[i](x, shapes, batch_size, model_comm_group=model_comm_group) return (x,) # return tuple for consistency with other processors diff --git a/src/anemoi/models/layers/processor.py b/src/anemoi/models/layers/processor.py index e77be034..a90448a6 100644 --- a/src/anemoi/models/layers/processor.py +++ b/src/anemoi/models/layers/processor.py @@ -164,7 +164,7 @@ def forward( model_comm_group.size() == 1 or batch_size == 1 ), "Only batch size of 1 is supported when model is sharded accross GPUs" - (x,) = self.run_layers((x,), shape_nodes, batch_size, model_comm_group, **kwargs) + (x,) = self.run_layers((x,), shape_nodes, batch_size, model_comm_group) return x diff --git a/src/anemoi/models/models/encoder_processor_decoder.py b/src/anemoi/models/models/encoder_processor_decoder.py index fd5eb139..4c7a4406 100644 --- a/src/anemoi/models/models/encoder_processor_decoder.py +++ b/src/anemoi/models/models/encoder_processor_decoder.py @@ -16,6 +16,7 @@ from anemoi.utils.config import DotDict from hydra.errors import InstantiationException from hydra.utils import instantiate +from omegaconf import OmegaConf from torch import Tensor from torch import nn from torch.distributed.distributed_c10d import ProcessGroup @@ -241,13 +242,13 @@ def forward(self, x: Tensor, model_comm_group: Optional[ProcessGroup] = None) -> def _load_layer_kernels(self, config: DotDict) -> 
None: - # If self.layer_kernels entry is missing from the config, use torch.nn by default - default_kernels = DotDict() - default_kernels["Linear"] = DotDict({"_target_": "torch.nn.Linear", "_partial_": True}) - default_kernels["LayerNorm"] = DotDict({"_target_": "torch.nn.LayerNorm", "_partial_": True}) - - # self.layer_kernels = config.get("model.layer_kernels", default_kernels) #Always uses default kernels... - self.layer_kernels = config.model.layer_kernels + # If self.layer_kernels entry is missing from the config, use torch.nn kernels + default_kernels = { + "Linear": {"_target_": "torch.nn.Linear", "_partial_": True}, + "LayerNorm": {"_target_": "torch.nn.LayerNorm", "_partial_": True}, + } + user_kernel = OmegaConf.select(config, "model.layer_kernels") + self.layer_kernels = {**default_kernels, **user_kernel} # Loop through all kernels in the layer_kernels config entry and try import them for kernel in self.layer_kernels: From b9c1ff9f7ce5a211fc4cfb83222a8efe0e9aeddf Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 12 Dec 2024 10:45:58 +0000 Subject: [PATCH 07/15] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/anemoi/models/layers/block.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/anemoi/models/layers/block.py b/src/anemoi/models/layers/block.py index cf04d815..a4ef1f6f 100644 --- a/src/anemoi/models/layers/block.py +++ b/src/anemoi/models/layers/block.py @@ -13,7 +13,6 @@ from abc import ABC from abc import abstractmethod from typing import Optional -from hydra.utils import instantiate import einops import torch From 4e350332e3a017a2b184dcc7cb4067078a4c403d Mon Sep 17 00:00:00 2001 From: Jakob Schloer Date: Wed, 18 Dec 2024 13:48:42 +0000 Subject: [PATCH 08/15] Add flexible layer kernels to GNN and GraphTransformer --- src/anemoi/models/layers/block.py | 18 +++++++++++------- src/anemoi/models/layers/chunk.py | 15 ++++++++------- src/anemoi/models/layers/conv.py | 6 ++++++ src/anemoi/models/layers/mapper.py | 21 ++++++++++++++++++++- src/anemoi/models/layers/mlp.py | 2 +- src/anemoi/models/layers/processor.py | 17 +++++++++-------- 6 files changed, 55 insertions(+), 24 deletions(-) diff --git a/src/anemoi/models/layers/block.py b/src/anemoi/models/layers/block.py index 988cef53..e06302f3 100644 --- a/src/anemoi/models/layers/block.py +++ b/src/anemoi/models/layers/block.py @@ -159,6 +159,7 @@ def __init__( self.conv = GraphConv( in_channels=in_channels, out_channels=out_channels, + layer_kernels=layer_kernels, mlp_extra_layers=mlp_extra_layers, activation=activation, ) @@ -192,11 +193,11 @@ def __ini__( self, in_channels=in_channels, out_channels=out_channels, + layer_kernels=layer_kernels, mlp_extra_layers=mlp_extra_layers, activation=activation, update_src_nodes=update_src_nodes, num_chunks=num_chunks, - layer_kernels=layer_kernels, **kwargs, ) @@ -250,11 +251,11 @@ def __ini__( self, in_channels=in_channels, out_channels=out_channels, + layer_kernels=layer_kernels, mlp_extra_layers=mlp_extra_layers, activation=activation, update_src_nodes=update_src_nodes, num_chunks=num_chunks, - layer_kernels=layer_kernels, **kwargs, ) @@ -365,18 +366,19 @@ def __init__( LOGGER.error("Activation function %s not supported", activation) raise RuntimeError from ae + self.layer_norm_attention = layerNorm(normalized_shape=in_channels) + self.layer_norm_mlp = layerNorm(normalized_shape=out_channels) + self.node_dst_mlp = nn.Sequential( - 
layerNorm(normalized_shape=out_channels), + self.layer_norm_mlp, linear(out_channels, hidden_dim), act_func(), linear(hidden_dim, out_channels), ) - self.layer_norm_attention = layerNorm(normalized_shape=in_channels) - if self.update_src_nodes: self.node_src_mlp = nn.Sequential( - layerNorm(normlaized_shape=out_channels), + self.layer_norm_mlp, linear(out_channels, hidden_dim), act_func(), linear(hidden_dim, out_channels), @@ -516,6 +518,7 @@ def forward( self.layer_norm_attention(x[0]), self.layer_norm_attention_2(x[1]), ) # Why does this use layer_norm_attention_2? And only is a mapper thing? + x_r = self.lin_self(x[1]) query = self.lin_query(x[1]) key = self.lin_key(x[0]) @@ -624,7 +627,8 @@ def __init__( bias=bias, activation=activation, num_chunks=num_chunks, - update_src_nodes=update_src_nodes**kwargs, + update_src_nodes=update_src_nodes, + **kwargs, ) def forward( diff --git a/src/anemoi/models/layers/chunk.py b/src/anemoi/models/layers/chunk.py index 179b63b3..ac0a0bce 100644 --- a/src/anemoi/models/layers/chunk.py +++ b/src/anemoi/models/layers/chunk.py @@ -72,8 +72,8 @@ def __init__( self, num_channels: int, num_layers: int, - window_size: int, layer_kernels: DotDict, + window_size: int, num_heads: int = 16, mlp_hidden_ratio: int = 4, activation: str = "GELU", @@ -87,11 +87,11 @@ def __init__( Number of channels num_layers : int Number of layers - window_size: int, - 1/2 size of shifted window for attention computation layer_kernels : DotDict A dict of layer implementations e.g. layer_kernels['Linear'] = "torch.nn.Linear" Defined in config/models/.yaml + window_size: int, + 1/2 size of shifted window for attention computation num_heads: int Number of heads to use, default 16 mlp_hidden_ratio: int @@ -110,8 +110,8 @@ def __init__( num_heads=num_heads, activation=activation, window_size=window_size, - dropout_p=dropout_p, layer_kernels=layer_kernels, + dropout_p=dropout_p, ) def forward( @@ -165,6 +165,7 @@ def __init__( in_features=edge_dim, hidden_dim=num_channels, out_features=num_channels, + layer_kernels=layer_kernels, n_extra_layers=mlp_extra_layers, activation=activation, ) @@ -175,9 +176,9 @@ def __init__( GraphConvProcessorBlock, num_channels, num_channels, + layer_kernels=layer_kernels, mlp_extra_layers=mlp_extra_layers, activation=activation, - layer_kernels=layer_kernels, ) def forward( @@ -239,10 +240,10 @@ def __init__( in_channels=num_channels, hidden_dim=mlp_hidden_ratio * num_channels, out_channels=num_channels, - num_heads=num_heads, edge_dim=edge_dim, - activation=activation, + num_heads=num_heads, layer_kernels=layer_kernels, + activation=activation, ) def forward( diff --git a/src/anemoi/models/layers/conv.py b/src/anemoi/models/layers/conv.py index 6b3a767e..5c354502 100644 --- a/src/anemoi/models/layers/conv.py +++ b/src/anemoi/models/layers/conv.py @@ -11,6 +11,7 @@ from typing import Optional import torch +from anemoi.utils.config import DotDict from torch import Tensor from torch.nn.functional import dropout from torch_geometric.nn.conv import MessagePassing @@ -31,6 +32,7 @@ def __init__( self, in_channels: int, out_channels: int, + layer_kernels: DotDict, mlp_extra_layers: int = 0, activation: str = "SiLU", **kwargs, @@ -43,6 +45,9 @@ def __init__( Number of input channels. out_channels : int Number of output channels. + layer_kernels : DotDict + A dict of layer implementations e.g. 
layer_kernels['Linear'] = "torch.nn.Linear" + Defined in config/models/.yaml mlp_extra_layers : int, optional Extra layers in MLP, by default 0 activation : str, optional @@ -54,6 +59,7 @@ def __init__( 3 * in_channels, out_channels, out_channels, + layer_kernels=layer_kernels, n_extra_layers=mlp_extra_layers, activation=activation, ) diff --git a/src/anemoi/models/layers/mapper.py b/src/anemoi/models/layers/mapper.py index c4537647..38c0c152 100644 --- a/src/anemoi/models/layers/mapper.py +++ b/src/anemoi/models/layers/mapper.py @@ -227,7 +227,6 @@ def __init__( num_chunks=num_chunks, cpu_offload=cpu_offload, activation=activation, - layer_kernels=layer_kernels, ) # Linear = layer_kernels.get("Linear", torch.nn.Linear) @@ -453,6 +452,7 @@ def __init__( sub_graph_edge_attributes: Optional[list[str]] = None, src_grid_size: int = 0, dst_grid_size: int = 0, + layer_kernels: DotDict = None, ) -> None: """Initialize GNNBaseMapper. @@ -476,6 +476,9 @@ def __init__( Whether to offload processing to CPU, by default False out_channels_dst : Optional[int], optional Output channels of the destination node, by default None + layer_kernels : DotDict + A dict of layer implementations e.g. layer_kernels['Linear'] = "torch.nn.Linear" + Defined in config/models/.yaml """ super().__init__( in_channels_src, @@ -493,6 +496,7 @@ def __init__( in_features=self.edge_dim, hidden_dim=hidden_dim, out_features=hidden_dim, + layer_kernels=layer_kernels, n_extra_layers=mlp_extra_layers, activation=activation, ) @@ -557,6 +561,7 @@ def __init__( sub_graph_edge_attributes: Optional[list[str]] = None, src_grid_size: int = 0, dst_grid_size: int = 0, + layer_kernels: DotDict = None, ) -> None: """Initialize GNNForwardMapper. @@ -580,6 +585,9 @@ def __init__( Whether to offload processing to CPU, by default False out_channels_dst : Optional[int], optional Output channels of the destination node, by default None + layer_kernels : DotDict + A dict of layer implementations e.g. layer_kernels['Linear'] = "torch.nn.Linear" + Defined in config/models/.yaml """ super().__init__( in_channels_src, @@ -595,11 +603,13 @@ def __init__( sub_graph_edge_attributes=sub_graph_edge_attributes, src_grid_size=src_grid_size, dst_grid_size=dst_grid_size, + layer_kernels=layer_kernels, ) self.proc = GraphConvMapperBlock( hidden_dim, hidden_dim, + layer_kernels=layer_kernels, mlp_extra_layers=mlp_extra_layers, activation=activation, update_src_nodes=True, @@ -612,6 +622,7 @@ def __init__( in_features=in_channels_src, hidden_dim=hidden_dim, out_features=hidden_dim, + layer_kernels=layer_kernels, n_extra_layers=mlp_extra_layers, activation=activation, ) @@ -620,6 +631,7 @@ def __init__( in_features=in_channels_dst, hidden_dim=hidden_dim, out_features=hidden_dim, + layer_kernels=layer_kernels, n_extra_layers=mlp_extra_layers, activation=activation, ) @@ -643,6 +655,7 @@ def __init__( sub_graph_edge_attributes: Optional[list[str]] = None, src_grid_size: int = 0, dst_grid_size: int = 0, + layer_kernels: DotDict = None, ) -> None: """Initialize GNNBackwardMapper. @@ -666,6 +679,9 @@ def __init__( Whether to offload processing to CPU, by default False out_channels_dst : Optional[int], optional Output channels of the destination node, by default None + layer_kernels : DotDict + A dict of layer implementations e.g. 
layer_kernels['Linear'] = "torch.nn.Linear" + Defined in config/models/.yaml """ super().__init__( in_channels_src, @@ -681,11 +697,13 @@ def __init__( sub_graph_edge_attributes=sub_graph_edge_attributes, src_grid_size=src_grid_size, dst_grid_size=dst_grid_size, + layer_kernels=layer_kernels, ) self.proc = GraphConvMapperBlock( hidden_dim, hidden_dim, + layer_kernels=layer_kernels, mlp_extra_layers=mlp_extra_layers, activation=activation, update_src_nodes=False, @@ -698,6 +716,7 @@ def __init__( in_features=self.hidden_dim, hidden_dim=self.hidden_dim, out_features=self.out_channels_dst, + layer_kernels=layer_kernels, n_extra_layers=mlp_extra_layers, activation=self.activation, layer_norm=False, diff --git a/src/anemoi/models/layers/mlp.py b/src/anemoi/models/layers/mlp.py index 5230b00c..af2ce74c 100644 --- a/src/anemoi/models/layers/mlp.py +++ b/src/anemoi/models/layers/mlp.py @@ -89,7 +89,7 @@ def __init__( mlp1.append(act_func()) if layer_norm: - mlp1.append(LayerNorm(out_features).as_type(out_features)) + mlp1.append(LayerNorm(normalized_shape=out_features)) self.model = CheckpointWrapper(mlp1) if checkpoints else mlp1 diff --git a/src/anemoi/models/layers/processor.py b/src/anemoi/models/layers/processor.py index a90448a6..5b77fdc0 100644 --- a/src/anemoi/models/layers/processor.py +++ b/src/anemoi/models/layers/processor.py @@ -124,27 +124,26 @@ def __init__( Dropout probability used for multi-head self attention, default 0.0 """ super().__init__( - num_channels=num_channels, num_layers=num_layers, + num_channels=num_channels, window_size=window_size, num_chunks=num_chunks, activation=activation, cpu_offload=cpu_offload, num_heads=num_heads, mlp_hidden_ratio=mlp_hidden_ratio, - # layer_kernels=layer_kernels, ) self.build_layers( TransformerProcessorChunk, num_channels=num_channels, + num_layers=self.chunk_size, + layer_kernels=layer_kernels, mlp_hidden_ratio=mlp_hidden_ratio, num_heads=num_heads, - num_layers=self.chunk_size, window_size=window_size, activation=activation, dropout_p=dropout_p, - layer_kernels=layer_kernels, ) self.offload_layers(cpu_offload) @@ -175,6 +174,7 @@ class GNNProcessor(GraphEdgeMixin, BaseProcessor): def __init__( self, num_layers: int, + layer_kernels: DotDict, *args, trainable_size: int = 8, num_channels: int = 128, @@ -219,16 +219,15 @@ def __init__( self.trainable = TrainableTensor(trainable_size=trainable_size, tensor_size=self.edge_attr.shape[0]) kwargs = { - "num_layers": self.chunk_size, "mlp_extra_layers": mlp_extra_layers, "activation": activation, "edge_dim": None, } - self.build_layers(GNNProcessorChunk, num_channels, **kwargs) + self.build_layers(GNNProcessorChunk, num_channels, self.chunk_size, layer_kernels, **kwargs) kwargs["edge_dim"] = self.edge_dim # Edge dim for first layer - self.proc[0] = GNNProcessorChunk(num_channels, **kwargs) + self.proc[0] = GNNProcessorChunk(num_channels, self.chunk_size, layer_kernels, **kwargs) self.offload_layers(cpu_offload) @@ -263,6 +262,7 @@ class GraphTransformerProcessor(GraphEdgeMixin, BaseProcessor): def __init__( self, num_layers: int, + layer_kernels: DotDict, trainable_size: int = 8, num_channels: int = 128, num_chunks: int = 2, @@ -296,8 +296,8 @@ def __init__( Whether to offload processing to CPU, by default False """ super().__init__( - num_layers=num_layers, num_channels=num_channels, + num_layers=num_layers, num_chunks=num_chunks, activation=activation, cpu_offload=cpu_offload, @@ -313,6 +313,7 @@ def __init__( GraphTransformerProcessorChunk, num_channels=num_channels, 
num_layers=self.chunk_size, + layer_kernels=layer_kernels, num_heads=num_heads, mlp_hidden_ratio=mlp_hidden_ratio, activation=activation, From 84dfdfc3d58243f5df48e341e5f7e9a21dc02ff3 Mon Sep 17 00:00:00 2001 From: Jakob Schloer Date: Wed, 18 Dec 2024 13:51:55 +0000 Subject: [PATCH 09/15] Add type annotation. --- src/anemoi/models/layers/normalization.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/anemoi/models/layers/normalization.py b/src/anemoi/models/layers/normalization.py index 7665c9ea..d3193c4a 100644 --- a/src/anemoi/models/layers/normalization.py +++ b/src/anemoi/models/layers/normalization.py @@ -9,6 +9,7 @@ from __future__ import annotations +from torch import Tensor from torch import nn From ad78c59caf137bfb24886c7d918a1c18e20468e5 Mon Sep 17 00:00:00 2001 From: Jakob Schloer Date: Fri, 20 Dec 2024 10:55:26 +0000 Subject: [PATCH 10/15] Add tests for layer kernels and adapt tests --- src/anemoi/models/layers/utils.py | 40 +++++++++++ .../models/encoder_processor_decoder.py | 27 +------ tests/layers/block/test_block_graphconv.py | 6 ++ .../block/test_block_graphtransformer.py | 23 ++++-- tests/layers/block/test_block_transformer.py | 26 +++++-- tests/layers/chunk/test_chunk_gnn.py | 10 ++- .../chunk/test_chunk_graphtransformer.py | 7 +- tests/layers/chunk/test_chunk_transformer.py | 6 ++ tests/layers/mapper/test_graphconv_mapper.py | 34 ++++++++- .../mapper/test_graphtransformer_mapper.py | 35 ++++++++- .../processor/test_graphconv_processor.py | 8 +++ .../test_graphtransformer_processor.py | 8 +++ .../processor/test_transformer_processor.py | 8 +++ tests/layers/test_attention.py | 11 ++- tests/layers/test_mlp.py | 23 +++--- tests/layers/test_utils.py | 71 +++++++++++++++++++ 16 files changed, 291 insertions(+), 52 deletions(-) create mode 100644 tests/layers/test_utils.py diff --git a/src/anemoi/models/layers/utils.py b/src/anemoi/models/layers/utils.py index 6bec46aa..6be87a4e 100644 --- a/src/anemoi/models/layers/utils.py +++ b/src/anemoi/models/layers/utils.py @@ -8,9 +8,17 @@ # nor does it submit to any jurisdiction. +import logging +from typing import Optional + +from anemoi.utils.config import DotDict +from hydra.errors import InstantiationException +from hydra.utils import instantiate from torch import nn from torch.utils.checkpoint import checkpoint +LOGGER = logging.getLogger(__name__) + class CheckpointWrapper(nn.Module): """Wrapper for checkpointing a module.""" @@ -21,3 +29,35 @@ def __init__(self, module: nn.Module) -> None: def forward(self, *args, **kwargs): return checkpoint(self.module, *args, **kwargs, use_reentrant=False) + + +def load_layer_kernels(kernel_config: Optional[DotDict] = {}) -> DotDict: + """Load layer kernels from the config. + + Args: + kernel_config : Optional[DotDict] + Kernel configuration + + Returns: + DotDict: hydra partial instantiation of the layer kernels + """ + # If self.layer_kernels entry is missing from the config, use torch.nn kernels + default_kernels = { + "Linear": {"_target_": "torch.nn.Linear", "_partial_": True}, + "LayerNorm": {"_target_": "torch.nn.LayerNorm", "_partial_": True}, + } + layer_kernels = {**default_kernels, **kernel_config} + + # Loop through all kernels in the layer_kernels config entry and try import them + for kernel in layer_kernels: + kernel_entry = layer_kernels[kernel] + try: + instantiate(kernel_entry) + except InstantiationException: + LOGGER.info( + f"{kernel_entry['_target_']} not found! check your config.model.layer_kernel.{kernel} entry. 
Maybe your desired kernel is not installed or the import string is incorrect?" + ) + raise InstantiationException + else: + LOGGER.info(f"{kernel} kernel: {kernel_entry}") + return layer_kernels diff --git a/src/anemoi/models/models/encoder_processor_decoder.py b/src/anemoi/models/models/encoder_processor_decoder.py index 4c7a4406..011c272d 100644 --- a/src/anemoi/models/models/encoder_processor_decoder.py +++ b/src/anemoi/models/models/encoder_processor_decoder.py @@ -14,7 +14,6 @@ import einops import torch from anemoi.utils.config import DotDict -from hydra.errors import InstantiationException from hydra.utils import instantiate from omegaconf import OmegaConf from torch import Tensor @@ -25,6 +24,7 @@ from anemoi.models.distributed.shapes import get_shape_shards from anemoi.models.layers.graph import NamedNodesAttributes +from anemoi.models.layers.utils import load_layer_kernels LOGGER = logging.getLogger(__name__) @@ -68,7 +68,7 @@ def __init__( input_dim = self.multi_step * self.num_input_channels + self.node_attributes.attr_ndims[self._graph_name_data] # read config.model.layer_kernels to get the implementation for certain layers - self._load_layer_kernels(model_config) + self.layer_kernels = load_layer_kernels(OmegaConf.select(model_config, "model.layer_kernels")) # Encoder data -> hidden self.encoder = instantiate( @@ -239,26 +239,3 @@ def forward(self, x: Tensor, model_comm_group: Optional[ProcessGroup] = None) -> x_out = bounding(x_out) return x_out - - def _load_layer_kernels(self, config: DotDict) -> None: - - # If self.layer_kernels entry is missing from the config, use torch.nn kernels - default_kernels = { - "Linear": {"_target_": "torch.nn.Linear", "_partial_": True}, - "LayerNorm": {"_target_": "torch.nn.LayerNorm", "_partial_": True}, - } - user_kernel = OmegaConf.select(config, "model.layer_kernels") - self.layer_kernels = {**default_kernels, **user_kernel} - - # Loop through all kernels in the layer_kernels config entry and try import them - for kernel in self.layer_kernels: - kernel_entry = self.layer_kernels[kernel] - try: - instantiate(kernel_entry) - except InstantiationException: - LOGGER.info( - f"{kernel_entry['_target_']} not found! check your config.model.layer_kernel.{kernel} entry. Maybe your desired kernel is not installed or the import string is incorrect?" - ) - raise InstantiationException - else: - LOGGER.info(f"{kernel} kernel: {kernel_entry}") diff --git a/tests/layers/block/test_block_graphconv.py b/tests/layers/block/test_block_graphconv.py index fe89cb63..cf8573aa 100644 --- a/tests/layers/block/test_block_graphconv.py +++ b/tests/layers/block/test_block_graphconv.py @@ -8,6 +8,7 @@ # nor does it submit to any jurisdiction. 
+from hydra.utils import instantiate from hypothesis import given from hypothesis import settings from hypothesis import strategies as st @@ -16,6 +17,7 @@ from anemoi.models.layers.block import GraphConvMapperBlock from anemoi.models.layers.block import GraphConvProcessorBlock from anemoi.models.layers.conv import GraphConv +from anemoi.models.layers.utils import load_layer_kernels class TestGraphConvProcessorBlock: @@ -37,9 +39,11 @@ def test_init( update_src_nodes, num_chunks, ): + layer_kernels = instantiate(load_layer_kernels()) block = GraphConvProcessorBlock( in_channels=in_channels, out_channels=out_channels, + layer_kernels=layer_kernels, mlp_extra_layers=mlp_extra_layers, activation=activation, update_src_nodes=update_src_nodes, @@ -73,9 +77,11 @@ def test_init( update_src_nodes, num_chunks, ): + layer_kernels = instantiate(load_layer_kernels()) block = GraphConvMapperBlock( in_channels=in_channels, out_channels=out_channels, + layer_kernels=layer_kernels, mlp_extra_layers=mlp_extra_layers, activation=activation, update_src_nodes=update_src_nodes, diff --git a/tests/layers/block/test_block_graphtransformer.py b/tests/layers/block/test_block_graphtransformer.py index a6162046..cb24aa94 100644 --- a/tests/layers/block/test_block_graphtransformer.py +++ b/tests/layers/block/test_block_graphtransformer.py @@ -13,11 +13,13 @@ import pytest import torch import torch.nn as nn +from hydra.utils import instantiate import anemoi.models.layers.block from anemoi.models.layers.block import GraphTransformerMapperBlock from anemoi.models.layers.block import GraphTransformerProcessorBlock from anemoi.models.layers.conv import GraphTransformerConv +from anemoi.models.layers.utils import load_layer_kernels @pytest.fixture @@ -30,11 +32,13 @@ def init(): activation = "GELU" num_heads = 8 num_chunks = 2 + layer_kernels = instantiate(load_layer_kernels()) return ( in_channels, hidden_dim, out_channels, edge_dim, + layer_kernels, bias, activation, num_heads, @@ -49,6 +53,7 @@ def block(init): hidden_dim, out_channels, edge_dim, + layer_kernels, bias, activation, num_heads, @@ -59,6 +64,7 @@ def block(init): hidden_dim=hidden_dim, out_channels=out_channels, edge_dim=edge_dim, + layer_kernels=layer_kernels, num_heads=num_heads, bias=bias, activation=activation, @@ -73,6 +79,7 @@ def test_GraphTransformerProcessorBlock_init(init, block): _hidden_dim, out_channels, _edge_dim, + _layer_kernels, _bias, _activation, num_heads, @@ -96,9 +103,6 @@ def test_GraphTransformerProcessorBlock_init(init, block): assert isinstance( block.node_dst_mlp, torch.nn.Sequential ), "block.node_dst_mlp is not an instance of torch.nn.Sequential" - assert isinstance( - block.layer_norm1, torch.nn.LayerNorm - ), "block.layer_norm1 is not an instance of torch.nn.LayerNorm" def test_GraphTransformerProcessorBlock_shard_qkve_heads(init, block): @@ -107,6 +111,7 @@ def test_GraphTransformerProcessorBlock_shard_qkve_heads(init, block): _hidden_dim, _out_channels, _edge_dim, + _layer_kernels, _bias, _activation, num_heads, @@ -131,6 +136,7 @@ def test_GraphTransformerProcessorBlock_shard_output_seq(init, block): _hidden_dim, _out_channels, _edge_dim, + _layer_kernels, _bias, _activation, num_heads, @@ -150,6 +156,7 @@ def test_GraphTransformerProcessorBlock_forward_backward(init, block): _hidden_dim, out_channels, edge_dim, + _layer_kernels, _bias, _activation, _num_heads, @@ -194,6 +201,7 @@ def mapper_block(init): hidden_dim, out_channels, edge_dim, + layer_kernels, bias, activation, num_heads, @@ -204,6 +212,7 @@ def 
mapper_block(init): hidden_dim=hidden_dim, out_channels=out_channels, edge_dim=edge_dim, + layer_kernels=layer_kernels, num_heads=num_heads, bias=bias, activation=activation, @@ -218,6 +227,7 @@ def test_GraphTransformerMapperBlock_init(init, mapper_block): _hidden_dim, out_channels, _edge_dim, + _layer_kernels, _bias, _activation, num_heads, @@ -240,9 +250,6 @@ def test_GraphTransformerMapperBlock_init(init, mapper_block): assert isinstance( block.node_dst_mlp, torch.nn.Sequential ), "block.node_dst_mlp is not an instance of torch.nn.Sequential" - assert isinstance( - block.layer_norm1, torch.nn.LayerNorm - ), "block.layer_norm1 is not an instance of torch.nn.LayerNorm" def test_GraphTransformerMapperBlock_shard_qkve_heads(init, mapper_block): @@ -251,6 +258,7 @@ def test_GraphTransformerMapperBlock_shard_qkve_heads(init, mapper_block): _hidden_dim, _out_channels, _edge_dim, + _layer_kernels, _bias, _activation, num_heads, @@ -276,6 +284,7 @@ def test_GraphTransformerMapperBlock_shard_output_seq(init, mapper_block): _hidden_dim, _out_channels, _edge_dim, + _layer_kernels, _bias, _activation, num_heads, @@ -295,6 +304,7 @@ def test_GraphTransformerMapperBlock_forward_backward(init, mapper_block): _hidden_dim, out_channels, edge_dim, + _layer_kernels, _bias, _activation, _num_heads, @@ -342,6 +352,7 @@ def test_GraphTransformerMapperBlock_chunking(init, mapper_block, monkeypatch): _hidden_dim, _out_channels, edge_dim, + _layer_kernels, _bias, _activation, _num_heads, diff --git a/tests/layers/block/test_block_transformer.py b/tests/layers/block/test_block_transformer.py index 46541e08..a343d5d4 100644 --- a/tests/layers/block/test_block_transformer.py +++ b/tests/layers/block/test_block_transformer.py @@ -11,6 +11,7 @@ import logging import torch +from hydra.utils import instantiate from hypothesis import given from hypothesis import settings from hypothesis import strategies as st @@ -21,6 +22,7 @@ from anemoi.models.layers.block import GraphConvProcessorBlock from anemoi.models.layers.block import TransformerProcessorBlock from anemoi.models.layers.conv import GraphConv +from anemoi.models.layers.utils import load_layer_kernels LOGGER = logging.getLogger(__name__) @@ -37,13 +39,20 @@ class TestTransformerProcessorBlock: @settings(max_examples=10) def test_init(self, factor_attention_heads, hidden_dim, num_heads, activation, window_size, dropout_p): num_channels = num_heads * factor_attention_heads + layer_kernels = instantiate(load_layer_kernels()) block = TransformerProcessorBlock( - num_channels, hidden_dim, num_heads, activation, window_size, dropout_p=dropout_p + num_channels, + hidden_dim, + num_heads, + activation, + window_size, + layer_kernels=layer_kernels, + dropout_p=dropout_p, ) assert isinstance(block, TransformerProcessorBlock) - assert isinstance(block.layer_norm1, nn.LayerNorm) - assert isinstance(block.layer_norm2, nn.LayerNorm) + assert isinstance(block.layer_norm_attention, nn.LayerNorm) + assert isinstance(block.layer_norm_mlp, nn.LayerNorm) assert isinstance(block.mlp, nn.Sequential) assert isinstance(block.attention, MultiHeadSelfAttention) @@ -70,8 +79,15 @@ def test_forward_output( dropout_p, ): num_channels = num_heads * factor_attention_heads + layer_kernels = instantiate(load_layer_kernels()) block = TransformerProcessorBlock( - num_channels, hidden_dim, num_heads, activation, window_size, dropout_p=dropout_p + num_channels, + hidden_dim, + num_heads, + activation, + window_size, + layer_kernels=layer_kernels, + dropout_p=dropout_p, ) x = 
torch.randn((batch_size, num_channels)) @@ -100,9 +116,11 @@ def test_init( update_src_nodes, num_chunks, ): + layer_kernels = instantiate(load_layer_kernels()) block = GraphConvProcessorBlock( in_channels=in_channels, out_channels=out_channels, + layer_kernels=layer_kernels, mlp_extra_layers=mlp_extra_layers, activation=activation, update_src_nodes=update_src_nodes, diff --git a/tests/layers/chunk/test_chunk_gnn.py b/tests/layers/chunk/test_chunk_gnn.py index f8ac8a12..993c2c7d 100644 --- a/tests/layers/chunk/test_chunk_gnn.py +++ b/tests/layers/chunk/test_chunk_gnn.py @@ -9,10 +9,12 @@ import pytest +from hydra.utils import instantiate from anemoi.models.layers.block import GraphConvProcessorBlock from anemoi.models.layers.chunk import GNNProcessorChunk from anemoi.models.layers.mlp import MLP +from anemoi.models.layers.utils import load_layer_kernels class TestGNNProcessorChunk: @@ -22,21 +24,23 @@ def init(self): num_layers = 3 mlp_extra_layers = 3 edge_dim = None - return num_channels, num_layers, mlp_extra_layers, edge_dim + layer_kernels = instantiate(load_layer_kernels()) + return num_channels, num_layers, layer_kernels, mlp_extra_layers, edge_dim @pytest.fixture def processor_chunk(self, init): - num_channels, num_layers, mlp_extra_layers, edge_dim = init + num_channels, num_layers, layer_kernels, mlp_extra_layers, edge_dim = init return GNNProcessorChunk( num_channels=num_channels, num_layers=num_layers, + layer_kernels=layer_kernels, mlp_extra_layers=mlp_extra_layers, activation="SiLU", edge_dim=edge_dim, ) def test_embed_edges(self, init, processor_chunk): - _num_channels, _num_layers, _mlp_extra_layers, edge_dim = init + _num_channels, _num_layers, _layer_kernels, _mlp_extra_layers, edge_dim = init if edge_dim: assert isinstance(processor_chunk.emb_edges, MLP) else: diff --git a/tests/layers/chunk/test_chunk_graphtransformer.py b/tests/layers/chunk/test_chunk_graphtransformer.py index 2f93cc91..ec4ae304 100644 --- a/tests/layers/chunk/test_chunk_graphtransformer.py +++ b/tests/layers/chunk/test_chunk_graphtransformer.py @@ -9,9 +9,11 @@ import pytest +from hydra.utils import instantiate from anemoi.models.layers.block import GraphTransformerProcessorBlock from anemoi.models.layers.chunk import GraphTransformerProcessorChunk +from anemoi.models.layers.utils import load_layer_kernels class TestGraphTransformerProcessorChunk: @@ -23,9 +25,11 @@ def init(self): mlp_hidden_ratio: int = 4 activation: str = "GELU" edge_dim: int = 32 + layer_kernels = instantiate(load_layer_kernels()) return ( num_channels, num_layers, + layer_kernels, num_heads, mlp_hidden_ratio, activation, @@ -34,10 +38,11 @@ def init(self): @pytest.fixture def processor_chunk(self, init): - num_channels, num_layers, num_heads, mlp_hidden_ratio, activation, edge_dim = init + num_channels, num_layers, layer_kernels, num_heads, mlp_hidden_ratio, activation, edge_dim = init return GraphTransformerProcessorChunk( num_channels=num_channels, num_layers=num_layers, + layer_kernels=layer_kernels, num_heads=num_heads, mlp_hidden_ratio=mlp_hidden_ratio, activation=activation, diff --git a/tests/layers/chunk/test_chunk_transformer.py b/tests/layers/chunk/test_chunk_transformer.py index 86989486..b9d758d3 100644 --- a/tests/layers/chunk/test_chunk_transformer.py +++ b/tests/layers/chunk/test_chunk_transformer.py @@ -9,9 +9,11 @@ import pytest +from hydra.utils import instantiate from anemoi.models.layers.block import TransformerProcessorBlock from anemoi.models.layers.chunk import TransformerProcessorChunk +from 
anemoi.models.layers.utils import load_layer_kernels class TestGraphTransformerProcessorChunk: @@ -24,11 +26,13 @@ def init(self): activation: str = "GELU" window_size: int = 13 dropout_p: float = 0.1 + layer_kernels = instantiate(load_layer_kernels()) # num_heads must be evenly divisible by num_channels for MHSA return ( num_channels, num_layers, + layer_kernels, num_heads, mlp_hidden_ratio, activation, @@ -41,6 +45,7 @@ def processor_chunk(self, init): ( num_channels, num_layers, + layer_kernels, num_heads, mlp_hidden_ratio, activation, @@ -50,6 +55,7 @@ def processor_chunk(self, init): return TransformerProcessorChunk( num_channels=num_channels, num_layers=num_layers, + layer_kernels=layer_kernels, num_heads=num_heads, mlp_hidden_ratio=mlp_hidden_ratio, activation=activation, diff --git a/tests/layers/mapper/test_graphconv_mapper.py b/tests/layers/mapper/test_graphconv_mapper.py index 1a756989..9d1a5d84 100644 --- a/tests/layers/mapper/test_graphconv_mapper.py +++ b/tests/layers/mapper/test_graphconv_mapper.py @@ -10,12 +10,15 @@ import pytest import torch +from hydra.utils import instantiate +from omegaconf import OmegaConf from torch import nn from torch_geometric.data import HeteroData from anemoi.models.layers.mapper import GNNBackwardMapper from anemoi.models.layers.mapper import GNNBaseMapper from anemoi.models.layers.mapper import GNNForwardMapper +from anemoi.models.layers.utils import load_layer_kernels class TestGNNBaseMapper: @@ -26,7 +29,21 @@ class TestGNNBaseMapper: NUM_EDGES: int = 300 @pytest.fixture - def mapper_init(self): + def layer_kernels(self): + kernel_config = OmegaConf.create( + { + "LayerNorm": { + "_target_": "torch.nn.LayerNorm", + "_partial_": True, + }, + "Linear": {"_target_": "torch.nn.Linear", "_partial_": True, "bias": False}, + } + ) + layer_kernels = load_layer_kernels(kernel_config) + return instantiate(layer_kernels) + + @pytest.fixture + def mapper_init(self, layer_kernels): in_channels_src: int = 3 in_channels_dst: int = 4 hidden_dim: int = 256 @@ -42,6 +59,7 @@ def mapper_init(self): cpu_offload, activation, trainable_size, + layer_kernels, ) @pytest.fixture @@ -54,6 +72,7 @@ def mapper(self, mapper_init, fake_graph): cpu_offload, activation, trainable_size, + layer_kernels, ) = mapper_init return GNNBaseMapper( in_channels_src=in_channels_src, @@ -65,6 +84,7 @@ def mapper(self, mapper_init, fake_graph): sub_graph=fake_graph[("src", "to", "dst")], sub_graph_edge_attributes=["edge_attr1", "edge_attr2"], trainable_size=trainable_size, + layer_kernels=layer_kernels, ) @pytest.fixture @@ -77,6 +97,7 @@ def pair_tensor(self, mapper_init): _cpu_offload, _activation, _trainable_size, + _layer_kernels, ) = mapper_init return ( torch.rand(self.NUM_SRC_NODES, in_channels_src), @@ -107,6 +128,7 @@ def test_initialization(self, mapper, mapper_init): _cpu_offload, activation, _trainable_size, + _layer_kernels, ) = mapper_init assert isinstance(mapper, GNNBaseMapper) assert mapper.in_channels_src == in_channels_src @@ -126,6 +148,7 @@ def test_pre_process(self, mapper, mapper_init, pair_tensor): _cpu_offload, _activation, _trainable_size, + _layer_kernels, ) = mapper_init shard_shapes = [list(x[0].shape)], [list(x[1].shape)] @@ -165,6 +188,7 @@ def mapper(self, mapper_init, fake_graph): cpu_offload, activation, trainable_size, + layer_kernels, ) = mapper_init return GNNForwardMapper( in_channels_src=in_channels_src, @@ -176,6 +200,7 @@ def mapper(self, mapper_init, fake_graph): sub_graph=fake_graph[("src", "to", "dst")], 
sub_graph_edge_attributes=["edge_attr1", "edge_attr2"], trainable_size=trainable_size, + layer_kernels=layer_kernels, ) def test_pre_process(self, mapper, mapper_init, pair_tensor): @@ -188,6 +213,7 @@ def test_pre_process(self, mapper, mapper_init, pair_tensor): _cpu_offload, _activation, _trainable_size, + _layer_kernels, ) = mapper_init shard_shapes = [list(x[0].shape)], [list(x[1].shape)] @@ -212,6 +238,7 @@ def test_forward_backward(self, mapper_init, mapper, pair_tensor): _cpu_offload, _activation, _trainable_size, + _layer_kernels, ) = mapper_init x = pair_tensor batch_size = 1 @@ -256,6 +283,7 @@ def mapper(self, mapper_init, fake_graph): cpu_offload, activation, trainable_size, + layer_kernels, ) = mapper_init return GNNBackwardMapper( in_channels_src=in_channels_src, @@ -267,6 +295,7 @@ def mapper(self, mapper_init, fake_graph): sub_graph=fake_graph[("src", "to", "dst")], sub_graph_edge_attributes=["edge_attr1", "edge_attr2"], trainable_size=trainable_size, + layer_kernels=layer_kernels, ) def test_pre_process(self, mapper, mapper_init, pair_tensor): @@ -279,6 +308,7 @@ def test_pre_process(self, mapper, mapper_init, pair_tensor): _cpu_offload, _activation, _trainable_size, + _layer_kernels, ) = mapper_init shard_shapes = [list(x[0].shape)], [list(x[1].shape)] @@ -303,6 +333,7 @@ def test_post_process(self, mapper, mapper_init): _cpu_offload, _activation, _trainable_size, + _layer_kernels, ) = mapper_init x_dst = torch.rand(self.NUM_DST_NODES, hidden_dim) shapes_dst = [list(x_dst.shape)] @@ -321,6 +352,7 @@ def test_forward_backward(self, mapper_init, mapper, pair_tensor): _cpu_offload, _activation, _trainable_size, + _layer_kernels, ) = mapper_init pair_tensor shard_shapes = [list(pair_tensor[0].shape)], [list(pair_tensor[1].shape)] diff --git a/tests/layers/mapper/test_graphtransformer_mapper.py b/tests/layers/mapper/test_graphtransformer_mapper.py index dece0e22..ba331dc3 100644 --- a/tests/layers/mapper/test_graphtransformer_mapper.py +++ b/tests/layers/mapper/test_graphtransformer_mapper.py @@ -10,12 +10,15 @@ import pytest import torch +from hydra.utils import instantiate +from omegaconf import OmegaConf from torch import nn from torch_geometric.data import HeteroData from anemoi.models.layers.mapper import GraphTransformerBackwardMapper from anemoi.models.layers.mapper import GraphTransformerBaseMapper from anemoi.models.layers.mapper import GraphTransformerForwardMapper +from anemoi.models.layers.utils import load_layer_kernels class TestGraphTransformerBaseMapper: @@ -26,7 +29,21 @@ class TestGraphTransformerBaseMapper: NUM_DST_NODES: int = 200 @pytest.fixture - def mapper_init(self): + def layer_kernels(self): + kernel_config = OmegaConf.create( + { + "LayerNorm": { + "_target_": "torch.nn.LayerNorm", + "_partial_": True, + }, + "Linear": {"_target_": "torch.nn.Linear", "_partial_": True, "bias": False}, + } + ) + layer_kernels = load_layer_kernels(kernel_config) + return instantiate(layer_kernels) + + @pytest.fixture + def mapper_init(self, layer_kernels): in_channels_src: int = 3 in_channels_dst: int = 3 hidden_dim: int = 256 @@ -46,6 +63,7 @@ def mapper_init(self): trainable_size, num_heads, mlp_hidden_ratio, + layer_kernels, ) @pytest.fixture @@ -60,6 +78,7 @@ def mapper(self, mapper_init, fake_graph): trainable_size, num_heads, mlp_hidden_ratio, + layer_kernels, ) = mapper_init return GraphTransformerBaseMapper( in_channels_src=in_channels_src, @@ -73,6 +92,7 @@ def mapper(self, mapper_init, fake_graph): trainable_size=trainable_size, num_heads=num_heads, 
mlp_hidden_ratio=mlp_hidden_ratio, + layer_kernels=layer_kernels, ) @pytest.fixture @@ -87,6 +107,7 @@ def pair_tensor(self, mapper_init): _trainable_size, _num_heads, _mlp_hidden_ratio, + _layer_kernels, ) = mapper_init return ( torch.rand(self.NUM_SRC_NODES, in_channels_src), @@ -119,6 +140,7 @@ def test_initialization(self, mapper, mapper_init): _trainable_size, _num_heads, _mlp_hidden_ratio, + _layer_kernels, ) = mapper_init assert isinstance(mapper, GraphTransformerBaseMapper) assert mapper.in_channels_src == in_channels_src @@ -126,6 +148,7 @@ def test_initialization(self, mapper, mapper_init): assert mapper.hidden_dim == hidden_dim assert mapper.out_channels_dst == out_channels_dst assert mapper.activation == activation + assert mapper.emb_nodes_dst.bias is None def test_pre_process(self, mapper, mapper_init, pair_tensor): # Should be a no-op in the base class @@ -140,6 +163,7 @@ def test_pre_process(self, mapper, mapper_init, pair_tensor): _trainable_size, _num_heads, _mlp_hidden_ratio, + _layer_kernels, ) = mapper_init shard_shapes = [list(x[0].shape)], [list(x[1].shape)] @@ -181,6 +205,7 @@ def mapper(self, mapper_init, fake_graph): trainable_size, num_heads, mlp_hidden_ratio, + layer_kernels, ) = mapper_init return GraphTransformerForwardMapper( in_channels_src=in_channels_src, @@ -194,6 +219,7 @@ def mapper(self, mapper_init, fake_graph): trainable_size=trainable_size, num_heads=num_heads, mlp_hidden_ratio=mlp_hidden_ratio, + layer_kernels=layer_kernels, ) def test_pre_process(self, mapper, mapper_init, pair_tensor): @@ -208,6 +234,7 @@ def test_pre_process(self, mapper, mapper_init, pair_tensor): _trainable_size, _num_heads, _mlp_hidden_ratio, + _layer_kernels, ) = mapper_init shard_shapes = [list(x[0].shape)], [list(x[1].shape)] @@ -234,6 +261,7 @@ def test_forward_backward(self, mapper_init, mapper, pair_tensor): _trainable_size, _num_heads, _mlp_hidden_ratio, + _layer_kernels, ) = mapper_init x = pair_tensor batch_size = 1 @@ -280,6 +308,7 @@ def mapper(self, mapper_init, fake_graph): trainable_size, _num_heads, _mlp_hidden_ratio, + layer_kernels, ) = mapper_init return GraphTransformerBackwardMapper( in_channels_src=in_channels_src, @@ -291,6 +320,7 @@ def mapper(self, mapper_init, fake_graph): sub_graph=fake_graph[("src", "to", "dst")], sub_graph_edge_attributes=["edge_attr1", "edge_attr2"], trainable_size=trainable_size, + layer_kernels=layer_kernels, ) def test_pre_process(self, mapper, mapper_init, pair_tensor): @@ -305,6 +335,7 @@ def test_pre_process(self, mapper, mapper_init, pair_tensor): _trainable_size, _num_heads, _mlp_hidden_ratio, + _layer_kernels, ) = mapper_init shard_shapes = [list(x[0].shape)], [list(x[1].shape)] @@ -331,6 +362,7 @@ def test_post_process(self, mapper, mapper_init): _trainable_size, _num_heads, _mlp_hidden_ratio, + _layer_kernels, ) = mapper_init x_dst = torch.rand(self.NUM_DST_NODES, hidden_dim) shapes_dst = [list(x_dst.shape)] @@ -351,6 +383,7 @@ def test_forward_backward(self, mapper_init, mapper, pair_tensor): _trainable_size, _num_heads, _mlp_hidden_ratio, + _layer_kernels, ) = mapper_init pair_tensor shard_shapes = [list(pair_tensor[0].shape)], [list(pair_tensor[1].shape)] diff --git a/tests/layers/processor/test_graphconv_processor.py b/tests/layers/processor/test_graphconv_processor.py index e847d64e..a13d440e 100644 --- a/tests/layers/processor/test_graphconv_processor.py +++ b/tests/layers/processor/test_graphconv_processor.py @@ -10,10 +10,12 @@ import pytest import torch +from hydra.utils import instantiate from 
torch_geometric.data import HeteroData from anemoi.models.layers.graph import TrainableTensor from anemoi.models.layers.processor import GNNProcessor +from anemoi.models.layers.utils import load_layer_kernels class TestGNNProcessor: @@ -44,8 +46,10 @@ def graphconv_init(self, fake_graph: HeteroData): src_grid_size = 0 dst_grid_size = 0 trainable_size = 8 + layer_kernels = instantiate(load_layer_kernels()) return ( num_layers, + layer_kernels, num_channels, num_chunks, mlp_extra_layers, @@ -62,6 +66,7 @@ def graphconv_init(self, fake_graph: HeteroData): def graphconv_processor(self, graphconv_init): ( num_layers, + layer_kernels, num_channels, num_chunks, mlp_extra_layers, @@ -75,6 +80,7 @@ def graphconv_processor(self, graphconv_init): ) = graphconv_init return GNNProcessor( num_layers, + layer_kernels, num_channels=num_channels, num_chunks=num_chunks, mlp_extra_layers=mlp_extra_layers, @@ -90,6 +96,7 @@ def graphconv_processor(self, graphconv_init): def test_graphconv_processor_init(self, graphconv_processor, graphconv_init): ( num_layers, + _layer_kernels, num_channels, num_chunks, _mlp_extra_layers, @@ -110,6 +117,7 @@ def test_forward(self, graphconv_processor, graphconv_init): batch_size = 1 ( _num_layers, + _layer_kernels, num_channels, _num_chunks, _mlp_extra_layers, diff --git a/tests/layers/processor/test_graphtransformer_processor.py b/tests/layers/processor/test_graphtransformer_processor.py index 95ba1c45..6c80bd04 100644 --- a/tests/layers/processor/test_graphtransformer_processor.py +++ b/tests/layers/processor/test_graphtransformer_processor.py @@ -10,10 +10,12 @@ import pytest import torch +from hydra.utils import instantiate from torch_geometric.data import HeteroData from anemoi.models.layers.graph import TrainableTensor from anemoi.models.layers.processor import GraphTransformerProcessor +from anemoi.models.layers.utils import load_layer_kernels class TestGraphTransformerProcessor: @@ -45,8 +47,10 @@ def graphtransformer_init(self, fake_graph: HeteroData): src_grid_size = 0 dst_grid_size = 0 trainable_size = 6 + layer_kernels = instantiate(load_layer_kernels()) return ( num_layers, + layer_kernels, num_channels, num_chunks, num_heads, @@ -64,6 +68,7 @@ def graphtransformer_init(self, fake_graph: HeteroData): def graphtransformer_processor(self, graphtransformer_init): ( num_layers, + layer_kernels, num_channels, num_chunks, num_heads, @@ -78,6 +83,7 @@ def graphtransformer_processor(self, graphtransformer_init): ) = graphtransformer_init return GraphTransformerProcessor( num_layers, + layer_kernels, num_channels=num_channels, num_chunks=num_chunks, num_heads=num_heads, @@ -94,6 +100,7 @@ def graphtransformer_processor(self, graphtransformer_init): def test_graphtransformer_processor_init(self, graphtransformer_processor, graphtransformer_init): ( num_layers, + _layer_kernels, num_channels, num_chunks, _num_heads, @@ -115,6 +122,7 @@ def test_forward(self, graphtransformer_processor, graphtransformer_init): batch_size = 1 ( _num_layers, + _layer_kernels, num_channels, _num_chunks, _num_heads, diff --git a/tests/layers/processor/test_transformer_processor.py b/tests/layers/processor/test_transformer_processor.py index b94ff63f..779b2047 100644 --- a/tests/layers/processor/test_transformer_processor.py +++ b/tests/layers/processor/test_transformer_processor.py @@ -10,8 +10,10 @@ import pytest import torch +from hydra.utils import instantiate from anemoi.models.layers.processor import TransformerProcessor +from anemoi.models.layers.utils import load_layer_kernels 
@pytest.fixture @@ -25,8 +27,10 @@ def transformer_processor_init(): num_heads = 16 mlp_hidden_ratio = 4 dropout_p = 0.1 + layer_kernels = instantiate(load_layer_kernels()) return ( num_layers, + layer_kernels, window_size, num_channels, num_chunks, @@ -42,6 +46,7 @@ def transformer_processor_init(): def transformer_processor(transformer_processor_init): ( num_layers, + layer_kernels, window_size, num_channels, num_chunks, @@ -53,6 +58,7 @@ def transformer_processor(transformer_processor_init): ) = transformer_processor_init return TransformerProcessor( num_layers=num_layers, + layer_kernels=layer_kernels, window_size=window_size, num_channels=num_channels, num_chunks=num_chunks, @@ -67,6 +73,7 @@ def transformer_processor(transformer_processor_init): def test_transformer_processor_init(transformer_processor, transformer_processor_init): ( num_layers, + _layer_kernels, _window_size, num_channels, num_chunks, @@ -85,6 +92,7 @@ def test_transformer_processor_init(transformer_processor, transformer_processor def test_transformer_processor_forward(transformer_processor, transformer_processor_init): ( _num_layers, + _layer_kernels, _window_size, num_channels, _num_chunks, diff --git a/tests/layers/test_attention.py b/tests/layers/test_attention.py index a1b40540..b3da6cad 100644 --- a/tests/layers/test_attention.py +++ b/tests/layers/test_attention.py @@ -12,10 +12,12 @@ import pytest import torch import torch.nn as nn +from hydra.utils import instantiate from hypothesis import given from hypothesis import settings from anemoi.models.layers.attention import MultiHeadSelfAttention +from anemoi.models.layers.utils import load_layer_kernels @given( @@ -27,7 +29,8 @@ def test_multi_head_self_attention_init(num_heads, embed_dim_multiplier, dropout embed_dim = ( num_heads * embed_dim_multiplier ) # TODO: Make assert in MHSA to check if embed_dim is divisible by num_heads - mhsa = MultiHeadSelfAttention(num_heads, embed_dim, dropout_p=dropout_p) + layer_kernels = instantiate(load_layer_kernels()) + mhsa = MultiHeadSelfAttention(num_heads, embed_dim, layer_kernels, dropout_p=dropout_p) assert isinstance(mhsa, nn.Module) assert mhsa.num_heads == num_heads @@ -46,7 +49,8 @@ def test_multi_head_self_attention_init(num_heads, embed_dim_multiplier, dropout @settings(deadline=None) def test_multi_head_self_attention_forward(batch_size, num_heads, embed_dim_multiplier, dropout_p): embed_dim = num_heads * embed_dim_multiplier - mhsa = MultiHeadSelfAttention(num_heads, embed_dim, dropout_p=dropout_p) + layer_kernels = instantiate(load_layer_kernels()) + mhsa = MultiHeadSelfAttention(num_heads, embed_dim, layer_kernels, dropout_p=dropout_p) x = torch.randn(batch_size * 2, embed_dim) shapes = [list(x.shape)] @@ -64,7 +68,8 @@ def test_multi_head_self_attention_forward(batch_size, num_heads, embed_dim_mult ) def test_multi_head_self_attention_backward(batch_size, num_heads, embed_dim_multiplier, dropout_p): embed_dim = num_heads * embed_dim_multiplier - mhsa = MultiHeadSelfAttention(num_heads, embed_dim, dropout_p=dropout_p) + layer_kernels = instantiate(load_layer_kernels()) + mhsa = MultiHeadSelfAttention(num_heads, embed_dim, layer_kernels, dropout_p=dropout_p) x = torch.randn(batch_size * 2, embed_dim, requires_grad=True) shapes = [list(x.shape)] diff --git a/tests/layers/test_mlp.py b/tests/layers/test_mlp.py index a5e93892..a3d93d03 100644 --- a/tests/layers/test_mlp.py +++ b/tests/layers/test_mlp.py @@ -10,8 +10,10 @@ import pytest import torch +from hydra.utils import instantiate from 
anemoi.models.layers.mlp import MLP +from anemoi.models.layers.utils import load_layer_kernels @pytest.fixture @@ -39,24 +41,29 @@ def num_out_feature(): return 36 +@pytest.fixture +def layer_kernels(): + return instantiate(load_layer_kernels()) + + class TestMLP: - def test_init(self, num_features, hdim, num_out_feature): + def test_init(self, num_features, hdim, num_out_feature, layer_kernels): """Test MLP initialization.""" - mlp = MLP(num_features, hdim, num_out_feature, 0, "SiLU") + mlp = MLP(num_features, hdim, num_out_feature, layer_kernels, n_extra_layers=0, activation="SiLU") assert isinstance(mlp, MLP) assert isinstance(mlp.model, torch.nn.Sequential) assert len(mlp.model) == 6 - mlp = MLP(num_features, hdim, num_out_feature, 0, "ReLU", False, False, False) + mlp = MLP(num_features, hdim, num_out_feature, layer_kernels, 0, "ReLU", False, False, False) assert len(mlp.model) == 5 - mlp = MLP(num_features, hdim, num_out_feature, 1, "SiLU", False, False, False) + mlp = MLP(num_features, hdim, num_out_feature, layer_kernels, 1, "SiLU", False, False, False) assert len(mlp.model) == 7 - def test_forwards(self, batch_size, nlatlon, num_features, hdim, num_out_feature): + def test_forwards(self, batch_size, nlatlon, num_features, hdim, num_out_feature, layer_kernels): """Test MLP forward pass.""" - mlp = MLP(num_features, hdim, num_out_feature, layer_norm=True) + mlp = MLP(num_features, hdim, num_out_feature, layer_kernels=layer_kernels, layer_norm=True) x_in = torch.randn((batch_size, nlatlon, num_features), dtype=torch.float32, requires_grad=True) out = mlp(x_in) @@ -66,11 +73,11 @@ def test_forwards(self, batch_size, nlatlon, num_features, hdim, num_out_feature num_out_feature, ), "Output shape is not correct" - def test_backward(self, batch_size, nlatlon, num_features, hdim): + def test_backward(self, batch_size, nlatlon, num_features, hdim, layer_kernels): """Test MLP backward pass.""" x_in = torch.randn((batch_size, nlatlon, num_features), dtype=torch.float32, requires_grad=True) - mlp_1 = MLP(num_features, hdim, hdim, layer_norm=True) + mlp_1 = MLP(num_features, hdim, hdim, layer_kernels, layer_norm=True) y = mlp_1(x_in) assert y.shape == (batch_size, nlatlon, hdim) diff --git a/tests/layers/test_utils.py b/tests/layers/test_utils.py new file mode 100644 index 00000000..1984ea62 --- /dev/null +++ b/tests/layers/test_utils.py @@ -0,0 +1,71 @@ +import pytest +import torch +from hydra.errors import InstantiationException +from hydra.utils import instantiate +from omegaconf import OmegaConf + +from anemoi.models.layers.utils import load_layer_kernels + + +@pytest.fixture +def default_layer_kernels(): + # Default layer kernels + kernels_config = OmegaConf.create( + { + "LayerNorm": { + "_target_": "torch.nn.LayerNorm", + "_partial_": True, + }, + } + ) + return instantiate(load_layer_kernels(kernels_config)) + + +@pytest.fixture +def custom_layer_kernels(): + # Custom layer kernels + kernels_config = OmegaConf.create( + { + "LayerNorm": { + "_target_": "torch.nn.LayerNorm", + "_partial_": True, + "eps": 1e-3, + "elementwise_affine": False, + }, + "Linear": {"_target_": "torch.nn.Linear", "_partial_": True, "bias": False}, + } + ) + return instantiate(load_layer_kernels(kernels_config)) + + +def test_kernels_init(default_layer_kernels): + """Test that the layer kernels are instantiated.""" + channels = 10 + linear_layer = default_layer_kernels["Linear"](in_features=channels, out_features=channels) + layer_norm = default_layer_kernels["LayerNorm"](normalized_shape=channels) + assert 
isinstance(linear_layer, torch.nn.Linear) + assert isinstance(layer_norm, torch.nn.LayerNorm) + assert linear_layer.bias.shape == torch.Size([channels]) + assert layer_norm.bias.shape == torch.Size([channels]) + + +def test_custom_kernels(custom_layer_kernels): + """Test that the custom layer kernels are instantiated.""" + linear_layer = custom_layer_kernels["Linear"](in_features=10, out_features=10) + layer_norm = custom_layer_kernels["LayerNorm"](normalized_shape=10) + + assert linear_layer.bias is None + assert layer_norm.bias is None + + +def test_unavailable_kernel(): + """Config with an unavailable kernel that should raise an error.""" + kernels_config = OmegaConf.create( + { + "LayerNorm": {"_target_": "nonexistent_package.LayerNorm", "_partial_": True}, + "Linear": {"_target_": "torch.nn.Linear", "_partial_": True}, + } + ) + # Catch InstantiationException + with pytest.raises(InstantiationException): + load_layer_kernels(kernels_config) From f7d70d900f11489a165a1a5f7ada92b89111c7e0 Mon Sep 17 00:00:00 2001 From: Jakob Schloer Date: Fri, 20 Dec 2024 11:45:56 +0000 Subject: [PATCH 11/15] Add layer kernel to config. --- training/src/anemoi/training/config/model/gnn.yaml | 11 +++++++++++ .../training/config/model/graphtransformer.yaml | 11 +++++++++++ .../src/anemoi/training/config/model/transformer.yaml | 11 +++++++++++ 3 files changed, 33 insertions(+) diff --git a/training/src/anemoi/training/config/model/gnn.yaml b/training/src/anemoi/training/config/model/gnn.yaml index 92a17fd4..00b04500 100644 --- a/training/src/anemoi/training/config/model/gnn.yaml +++ b/training/src/anemoi/training/config/model/gnn.yaml @@ -4,6 +4,17 @@ num_channels: 512 model: _target_: anemoi.models.models.encoder_processor_decoder.AnemoiModelEncProcDec +layer_kernels: + LayerNorm: + _target_: torch.nn.LayerNorm #the default PyTorch implementation + _partial_: True + #Any arguments to your chosen function go here e.g. + Linear: + #_target_: "transformer_engine.pytorch.Linear" + _target_: torch.nn.Linear + _partial_: True + #Any arguments to your chosen function go here e.g. + processor: _target_: anemoi.models.layers.processor.GNNProcessor _convert_: all diff --git a/training/src/anemoi/training/config/model/graphtransformer.yaml b/training/src/anemoi/training/config/model/graphtransformer.yaml index 9c48967b..63e4634b 100644 --- a/training/src/anemoi/training/config/model/graphtransformer.yaml +++ b/training/src/anemoi/training/config/model/graphtransformer.yaml @@ -4,6 +4,17 @@ num_channels: 1024 model: _target_: anemoi.models.models.encoder_processor_decoder.AnemoiModelEncProcDec +layer_kernels: + LayerNorm: + _target_: torch.nn.LayerNorm #the default PyTorch implementation + _partial_: True + #Any arguments to your chosen function go here e.g. + Linear: + #_target_: "transformer_engine.pytorch.Linear" + _target_: torch.nn.Linear + _partial_: True + #Any arguments to your chosen function go here e.g. 
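+    #(illustrative continuation, not part of the original patch) any extra kwargs
+    #listed here are forwarded to the chosen class, for instance a bias-free Linear:
+    #bias: False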
+ processor: _target_: anemoi.models.layers.processor.GraphTransformerProcessor _convert_: all diff --git a/training/src/anemoi/training/config/model/transformer.yaml b/training/src/anemoi/training/config/model/transformer.yaml index cd6a1e7b..142485f2 100644 --- a/training/src/anemoi/training/config/model/transformer.yaml +++ b/training/src/anemoi/training/config/model/transformer.yaml @@ -4,6 +4,17 @@ num_channels: 1024 model: _target_: anemoi.models.models.encoder_processor_decoder.AnemoiModelEncProcDec +layer_kernels: + LayerNorm: + _target_: torch.nn.LayerNorm #the default PyTorch implementation + _partial_: True + #Any arguments to your chosen function go here e.g. + Linear: + #_target_: "transformer_engine.pytorch.Linear" + _target_: torch.nn.Linear + _partial_: True + #Any arguments to your chosen function go here e.g. + processor: _target_: anemoi.models.layers.processor.TransformerProcessor _convert_: all From 022e34dc378cbd4dc174d66880d70fb6f6799638 Mon Sep 17 00:00:00 2001 From: Jakob Schloer Date: Fri, 20 Dec 2024 11:59:22 +0000 Subject: [PATCH 12/15] Add conditional LN. Implementation originally written by @ssmmnn11 --- .../src/anemoi/models/layers/normalization.py | 49 +++++++++++++++++++ 1 file changed, 49 insertions(+) diff --git a/models/src/anemoi/models/layers/normalization.py b/models/src/anemoi/models/layers/normalization.py index d3193c4a..cfd99c7a 100644 --- a/models/src/anemoi/models/layers/normalization.py +++ b/models/src/anemoi/models/layers/normalization.py @@ -9,6 +9,10 @@ from __future__ import annotations +from typing import List +from typing import Union + +from torch import Size from torch import Tensor from torch import nn @@ -26,3 +30,48 @@ def forward(self, x: Tensor) -> Tensor: precision. """ return super().forward(x).type_as(x) + + +class ConditionalLayerNorm(nn.Module): + """Conditional Layer Normalization. + + x_norm = a(u) * (x - mean) / sqrt(var + eps) + b(u) + + """ + + def __init__( + self, + normalized_shape: Union[int, list, Size], + condition_shape: int = 16, + w_one_bias_zero_init: bool = True, + autocast: bool = True, + ): + super().__init__() + self.norm = nn.LayerNorm(normalized_shape, elementwise_affine=False) # no learnable parameters + self.scale = nn.Linear(condition_shape, normalized_shape) # , bias=False) + self.bias = nn.Linear(condition_shape, normalized_shape) # , bias=False) + self.autocast = autocast + + if w_one_bias_zero_init: + nn.init.ones_(self.scale.weight) + nn.init.zeros_(self.scale.bias) + nn.init.zeros_(self.bias.weight) + nn.init.zeros_(self.bias.bias) + + def forward(self, input: List[Tensor, Tensor]) -> Tensor: + """Conditional Layer Normalization. + + Args: + input (List[Tensor, Tensor]): A list of two tensors (x, cond), + the first is the input tensor and + the second is the condition tensor. + + Returns: + Tensor: The output tensor. + """ + x, cond = input + scale = self.scale(cond) + bias = self.bias(cond) + out = self.norm(x) + out = out * (scale + 1.0) + bias + return out.type_as(x) if self.autocast else out From cac4ad4e83bf5637413910cc4ef2f7eab5b9cf1a Mon Sep 17 00:00:00 2001 From: Jakob Schloer Date: Wed, 8 Jan 2025 16:38:18 +0000 Subject: [PATCH 13/15] Including @ssmmnn11 comments. 
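A minimal sketch of how the pieces introduced in this series fit together (illustrative only, not part of any commit in the series; the channel and condition sizes below are arbitrary): load_layer_kernels merges the user's model.layer_kernels config with the torch.nn defaults, hydra's instantiate turns every entry into a partial constructor, and the blocks then call those partials exactly like ordinary layer classes. ConditionalLayerNorm from the previous commit is called with an (x, condition) pair.

    import torch
    from hydra.utils import instantiate
    from omegaconf import OmegaConf

    from anemoi.models.layers.normalization import ConditionalLayerNorm
    from anemoi.models.layers.utils import load_layer_kernels

    # Only Linear is overridden here; LayerNorm falls back to the torch.nn default.
    kernel_config = OmegaConf.create(
        {"Linear": {"_target_": "torch.nn.Linear", "_partial_": True, "bias": False}}
    )
    kernels = instantiate(load_layer_kernels(kernel_config))

    channels, cond_dim = 128, 16  # illustrative sizes only
    linear = kernels["Linear"](in_features=channels, out_features=channels)  # bias-free torch.nn.Linear
    norm = kernels["LayerNorm"](normalized_shape=channels)                   # plain torch.nn.LayerNorm

    # ConditionalLayerNorm normalises x without affine parameters, then scales and
    # shifts the result with two small Linear maps of the condition tensor.
    cond_norm = ConditionalLayerNorm(normalized_shape=channels, condition_shape=cond_dim)
    out = cond_norm((torch.randn(4, channels), torch.randn(4, cond_dim)))  # shape (4, channels)

Swapping in an alternative kernel (for instance a fused LayerNorm or a different Linear implementation) is then a pure config change; the blocks themselves stay untouched.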
--- models/src/anemoi/models/layers/block.py | 10 +++++----- models/src/anemoi/models/layers/chunk.py | 1 - training/src/anemoi/training/config/model/gnn.yaml | 3 +-- .../anemoi/training/config/model/graphtransformer.yaml | 1 - .../src/anemoi/training/config/model/transformer.yaml | 1 - 5 files changed, 6 insertions(+), 10 deletions(-) diff --git a/models/src/anemoi/models/layers/block.py b/models/src/anemoi/models/layers/block.py index dc16f91f..e04c6ed1 100644 --- a/models/src/anemoi/models/layers/block.py +++ b/models/src/anemoi/models/layers/block.py @@ -357,7 +357,6 @@ def __init__( self.conv = GraphTransformerConv(out_channels=self.out_channels_conv) - # Why does the GraphTransformer not have a layer_norm_mlp like the Transformer? self.projection = linear(out_channels, out_channels) try: @@ -500,7 +499,8 @@ def __init__( **kwargs, ) - self.layer_norm_attention_2 = layer_kernels["LayerNorm"](normalized_shape=in_channels) + self.layer_norm_attention_src = self.layer_norm_attention + self.layer_norm_attention_dest = layer_kernels["LayerNorm"](normalized_shape=in_channels) def forward( self, @@ -515,9 +515,9 @@ def forward( x_skip = x x = ( - self.layer_norm_attention(x[0]), - self.layer_norm_attention_2(x[1]), - ) # Why does this use layer_norm_attention_2? And only is a mapper thing? + self.layer_norm_attention_src(x[0]), + self.layer_norm_attention_dest(x[1]), + ) x_r = self.lin_self(x[1]) query = self.lin_query(x[1]) diff --git a/models/src/anemoi/models/layers/chunk.py b/models/src/anemoi/models/layers/chunk.py index e5d6abed..169b2919 100644 --- a/models/src/anemoi/models/layers/chunk.py +++ b/models/src/anemoi/models/layers/chunk.py @@ -38,7 +38,6 @@ def __init__( num_layers: int, *args, activation: str = "GELU", - layer_norm: Optional[dict] = None, **kwargs, ) -> None: """Initialize BaseProcessorChunk.""" diff --git a/training/src/anemoi/training/config/model/gnn.yaml b/training/src/anemoi/training/config/model/gnn.yaml index 00b04500..7c6d671b 100644 --- a/training/src/anemoi/training/config/model/gnn.yaml +++ b/training/src/anemoi/training/config/model/gnn.yaml @@ -6,11 +6,10 @@ model: layer_kernels: LayerNorm: - _target_: torch.nn.LayerNorm #the default PyTorch implementation + _target_: anemoi.models.layers.normalization.AutocastLayerNorm #the default PyTorch implementation _partial_: True #Any arguments to your chosen function go here e.g. Linear: - #_target_: "transformer_engine.pytorch.Linear" _target_: torch.nn.Linear _partial_: True #Any arguments to your chosen function go here e.g. diff --git a/training/src/anemoi/training/config/model/graphtransformer.yaml b/training/src/anemoi/training/config/model/graphtransformer.yaml index 63e4634b..39112082 100644 --- a/training/src/anemoi/training/config/model/graphtransformer.yaml +++ b/training/src/anemoi/training/config/model/graphtransformer.yaml @@ -10,7 +10,6 @@ layer_kernels: _partial_: True #Any arguments to your chosen function go here e.g. Linear: - #_target_: "transformer_engine.pytorch.Linear" _target_: torch.nn.Linear _partial_: True #Any arguments to your chosen function go here e.g. diff --git a/training/src/anemoi/training/config/model/transformer.yaml b/training/src/anemoi/training/config/model/transformer.yaml index 142485f2..d8d7764f 100644 --- a/training/src/anemoi/training/config/model/transformer.yaml +++ b/training/src/anemoi/training/config/model/transformer.yaml @@ -10,7 +10,6 @@ layer_kernels: _partial_: True #Any arguments to your chosen function go here e.g. 
Linear: - #_target_: "transformer_engine.pytorch.Linear" _target_: torch.nn.Linear _partial_: True #Any arguments to your chosen function go here e.g. From e4da1c0e71687df431d5b679b5a0318626000496 Mon Sep 17 00:00:00 2001 From: Jakob Schloer Date: Wed, 8 Jan 2025 16:48:18 +0000 Subject: [PATCH 14/15] Change docstring style --- .../src/anemoi/models/layers/normalization.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/models/src/anemoi/models/layers/normalization.py b/models/src/anemoi/models/layers/normalization.py index cfd99c7a..f2773a9b 100644 --- a/models/src/anemoi/models/layers/normalization.py +++ b/models/src/anemoi/models/layers/normalization.py @@ -61,13 +61,17 @@ def __init__( def forward(self, input: List[Tensor, Tensor]) -> Tensor: """Conditional Layer Normalization. - Args: - input (List[Tensor, Tensor]): A list of two tensors (x, cond), - the first is the input tensor and - the second is the condition tensor. - - Returns: - Tensor: The output tensor. + Parameters + ---------- + input : List[Tensor, Tensor] + A list of two tensors (x, cond), + the first is the input tensor and + the second is the condition tensor. + + Returns + ------- + Tensor + The output tensor. """ x, cond = input scale = self.scale(cond) From 26ca381358906a1dfc307ff91ca7367f9f28de43 Mon Sep 17 00:00:00 2001 From: Jakob Schloer Date: Wed, 8 Jan 2025 21:30:55 +0000 Subject: [PATCH 15/15] Remove comments in config. --- training/src/anemoi/training/config/model/gnn.yaml | 2 +- training/src/anemoi/training/config/model/graphtransformer.yaml | 2 +- training/src/anemoi/training/config/model/transformer.yaml | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/training/src/anemoi/training/config/model/gnn.yaml b/training/src/anemoi/training/config/model/gnn.yaml index 7c6d671b..acaf2a7a 100644 --- a/training/src/anemoi/training/config/model/gnn.yaml +++ b/training/src/anemoi/training/config/model/gnn.yaml @@ -6,7 +6,7 @@ model: layer_kernels: LayerNorm: - _target_: anemoi.models.layers.normalization.AutocastLayerNorm #the default PyTorch implementation + _target_: anemoi.models.layers.normalization.AutocastLayerNorm _partial_: True #Any arguments to your chosen function go here e.g. Linear: diff --git a/training/src/anemoi/training/config/model/graphtransformer.yaml b/training/src/anemoi/training/config/model/graphtransformer.yaml index 39112082..f8d7a3f1 100644 --- a/training/src/anemoi/training/config/model/graphtransformer.yaml +++ b/training/src/anemoi/training/config/model/graphtransformer.yaml @@ -6,7 +6,7 @@ model: layer_kernels: LayerNorm: - _target_: torch.nn.LayerNorm #the default PyTorch implementation + _target_: torch.nn.LayerNorm _partial_: True #Any arguments to your chosen function go here e.g. Linear: diff --git a/training/src/anemoi/training/config/model/transformer.yaml b/training/src/anemoi/training/config/model/transformer.yaml index d8d7764f..1b5c37fe 100644 --- a/training/src/anemoi/training/config/model/transformer.yaml +++ b/training/src/anemoi/training/config/model/transformer.yaml @@ -6,7 +6,7 @@ model: layer_kernels: LayerNorm: - _target_: torch.nn.LayerNorm #the default PyTorch implementation + _target_: torch.nn.LayerNorm _partial_: True #Any arguments to your chosen function go here e.g. Linear: