diff --git a/axlearn/common/attention.py b/axlearn/common/attention.py
index 9c497509..02a88e38 100644
--- a/axlearn/common/attention.py
+++ b/axlearn/common/attention.py
@@ -4006,7 +4006,9 @@ class RematRegexSavePatterns(enum.Enum):
     CONTEXT = r".*context"
     LINEAR1_X = r".*linear1_[01]"
     LINEAR2_X = r".*linear2_[01]"
-    SELF_ATTENTION = ".*([qkvo]_proj|context)"
+    # This is called native attention because the "context" remat point only exists when using
+    # native attention, e.g. `MultiheadAttention` or `GroupedQueryAttention`.
+    NATIVE_ATTENTION = ".*([qkvo]_proj|context)"
     FEED_FORWARD = "|".join([LINEAR1_X, LINEAR2_X])
 
 
@@ -4014,7 +4016,7 @@ def build_remat_spec(
     stack_cfg: Union[
         BaseStackedTransformerLayer.Config, "RepeatedConformerLayer.Config"  # type: ignore
     ],
-    save_pattern: SavePattern = RematRegexSavePatterns.SELF_ATTENTION.value,
+    save_pattern: SavePattern = RematRegexSavePatterns.NATIVE_ATTENTION.value,
     offload_pattern: SavePattern = None,
     offload_dst: str = "pinned_host",
 ) -> Optional[RematSpec]:
@@ -4028,8 +4030,10 @@ def build_remat_spec(
     TODO(zhiyunlu): investigate Conformer model's memory/step-time tradeoffs. Possibly we need to
     save points in the LConv module.
 
-    Note that the default `save_pattern`, `NATIVE_ATTENTION_SAVE_PATTERN`, doesn't save the
-    context tensor when using FlashAttention. To save it when using FlashAttention, do this
+    Note that the default `save_pattern`, `NATIVE_ATTENTION`, doesn't save the context tensor when
+    using `FlashAttention`. To save it when using `FlashAttention`, use the policy from the module
+    `axlearn.common.flash_attention.remat`:
+
     ```python
     from axlearn.common.utils import save_and_offload_these_names_regex
     from axlearn.common.flash_attention.remat import save_or_offload_flash_attention_policy
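
For context, a minimal sketch of how the renamed pattern might be used when configuring a stacked transformer. The `RepeatedTransformerLayer` stack, the `num_layers` setting, and attaching the result via the per-layer `remat_spec` config field are illustrative assumptions about the surrounding axlearn config API, not part of this diff:

```python
# Illustrative sketch only: the stack type and config fields below are assumptions
# about the surrounding axlearn API, not part of this diff.
from axlearn.common.attention import (
    RematRegexSavePatterns,
    RepeatedTransformerLayer,  # assumed stack; any BaseStackedTransformerLayer config works
    build_remat_spec,
)

stack_cfg = RepeatedTransformerLayer.default_config().set(num_layers=12)

# Default behavior: save the q/k/v/o projections plus the "context" tensor,
# which only exists as a remat point for native attention implementations.
remat_spec = build_remat_spec(stack_cfg)

# Equivalent explicit form; any RematRegexSavePatterns value (or a custom regex) works.
remat_spec = build_remat_spec(
    stack_cfg,
    save_pattern=RematRegexSavePatterns.NATIVE_ATTENTION.value,
)

# Patterns compose with "|", mirroring how FEED_FORWARD is defined in the enum.
combined = "|".join(
    [
        RematRegexSavePatterns.NATIVE_ATTENTION.value,
        RematRegexSavePatterns.FEED_FORWARD.value,
    ]
)
remat_spec = build_remat_spec(stack_cfg, save_pattern=combined)

# Attach to the per-layer config (assuming the remat_spec field on the layer's BaseLayer.Config).
stack_cfg.layer.remat_spec = remat_spec
```

As the updated docstring notes, this default pattern does not cover the FlashAttention remat point; in that case the policy from `axlearn.common.flash_attention.remat` shown in the docstring snippet applies instead.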