Commit: Update comments

hanzhi713 committed Jan 14, 2025
1 parent (2d33641) · commit f76f51f
Showing 1 changed file with 8 additions and 4 deletions.
axlearn/common/attention.py (12 changes: 8 additions & 4 deletions)
@@ -4006,15 +4006,17 @@ class RematRegexSavePatterns(enum.Enum):
     CONTEXT = r".*context"
     LINEAR1_X = r".*linear1_[01]"
     LINEAR2_X = r".*linear2_[01]"
-    SELF_ATTENTION = ".*([qkvo]_proj|context)"
+    # This is called native attention because the "context" remat point only exists when using
+    # native attention, e.g. `MultiheadAttention` or `GroupedQueryAttention`.
+    NATIVE_ATTENTION = ".*([qkvo]_proj|context)"
     FEED_FORWARD = "|".join([LINEAR1_X, LINEAR2_X])
 
 
 def build_remat_spec(
     stack_cfg: Union[
         BaseStackedTransformerLayer.Config, "RepeatedConformerLayer.Config"  # type: ignore
     ],
-    save_pattern: SavePattern = RematRegexSavePatterns.SELF_ATTENTION.value,
+    save_pattern: SavePattern = RematRegexSavePatterns.NATIVE_ATTENTION.value,
     offload_pattern: SavePattern = None,
     offload_dst: str = "pinned_host",
 ) -> Optional[RematSpec]:
@@ -4028,8 +4030,10 @@ def build_remat_spec(
     TODO(zhiyunlu): investigate Conformer model's memory/step-time tradeoffs. Possibly we
     need to save points in the LConv module.
-    Note that the default `save_pattern`, `NATIVE_ATTENTION_SAVE_PATTERN`, doesn't save the
-    context tensor when using FlashAttention. To save it when using FlashAttention, do this
+    Note that the default `save_pattern`, `NATIVE_ATTENTION`, doesn't save the context tensor when
+    using `FlashAttention`. To save it when using `FlashAttention`, use the policy from the module
+    `axlearn.common.flash_attention.remat`:
     ```python
     from axlearn.common.utils import save_and_offload_these_names_regex
     from axlearn.common.flash_attention.remat import save_or_offload_flash_attention_policy

(remaining diff lines not shown)
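For reference, below is a minimal sketch of how the renamed pattern might be passed to `build_remat_spec`; it is not part of the commit. `stack_cfg` is a placeholder for an already-configured `BaseStackedTransformerLayer.Config`, and the regex alternation used to combine patterns mirrors how `FEED_FORWARD` is built from `LINEAR1_X` and `LINEAR2_X` in the enum above.

```python
# Sketch only: `stack_cfg` is assumed to be an existing
# BaseStackedTransformerLayer.Config (or RepeatedConformerLayer.Config).
from axlearn.common.attention import RematRegexSavePatterns, build_remat_spec

# Default after this commit: save the q/k/v/o projection outputs and, for native
# attention implementations, the "context" remat point.
remat_spec = build_remat_spec(stack_cfg)

# Patterns can be combined via regex alternation, the same way FEED_FORWARD is
# constructed in the enum, to also save the feed-forward linear outputs.
combined_pattern = "|".join(
    [
        RematRegexSavePatterns.NATIVE_ATTENTION.value,
        RematRegexSavePatterns.FEED_FORWARD.value,
    ]
)
remat_spec = build_remat_spec(stack_cfg, save_pattern=combined_pattern)
```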
