Skip to content

Commit

Permalink
Address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
hanzhi713 committed Jan 13, 2025
1 parent 24071f5 commit bff51dc
Showing 1 changed file with 3 additions and 0 deletions.
3 changes: 3 additions & 0 deletions axlearn/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,13 +122,16 @@ def save_and_offload_only_these_names_regex(
offload_dst: str,
) -> RematPolicy:
"""Adapted from jax source code to support regex.
Reference:
https://github.com/jax-ml/jax/blob/0d36b0b433a93c707f86dac89b0c05d40302775a/jax/_src/ad_checkpoint.py#L120
Args:
names_which_can_be_saved: A regex pattern for names which can be saved.
names_which_can_be_offloaded: A regex pattern for names which can be offloaded.
offload_src: The source device for offloading.
offload_dst: The target device for offloading.
Returns:
A policy function that offloads and saves only the tensors that match the given
regex patterns.
Expand Down

0 comments on commit bff51dc

Please sign in to comment.