Skip to content

Commit

Permalink
Fix
Browse files Browse the repository at this point in the history
  • Loading branch information
hanzhi713 committed Jan 16, 2025
1 parent 9a04606 commit e63ae33
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 3 deletions.
9 changes: 7 additions & 2 deletions axlearn/common/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ class Config(Module.Config):
# pre-compilation checks (such as sharding check) that increases the step time for some
# models. Note that this cache is always disabled at steps when xsc is enabled.
# Defaults to None which is interpreted as True.
enable_python_train_step_cache: Optional[bool] = None
cache_python_train_step: Optional[bool] = None

def __init__(
self,
Expand Down Expand Up @@ -279,6 +279,8 @@ def __init__(
"xsc_check_policy was set for non-TPU XLA backend. Running without XSC."
)
else:
if cfg.cache_python_train_step is True:
raise ValueError("cache_python_train_step cannot be True when xsc is enabled.")
xsc_check_policy = maybe_instantiate(cfg.xsc_check_policy)
self._xsc_check_policy: Optional[Callable[[int], bool]] = xsc_check_policy
self._compiled_train_step: Optional[jax.stages.Compiled] = None
Expand Down Expand Up @@ -973,6 +975,9 @@ def _get_compiled_train_step_fn(
) -> Callable[[TrainerState, NestedTensor], tuple[TrainerState, NestedTensor]]:
"""Build a fully compiled train step function.
Relies on the JAX pjit cache to avoid recompilation when with_xsc=True or
cache_python_train_step=False.
Args:
train_state: A TrainerState instance.
input_batch: A NestedTensor containing global arrays.
Expand All @@ -985,7 +990,7 @@ def _get_compiled_train_step_fn(
RuntimeError: If `with_xsc` is requested on heterogenous device kinds.
"""
if (
not (self.config.enable_python_train_step_cache is False)
not (self.config.cache_python_train_step is False)
and not with_xsc
and self._compiled_train_step is not None
):
Expand Down
2 changes: 1 addition & 1 deletion axlearn/common/trainer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,7 +488,7 @@ def test_xsc_check_policy_and_compilation_cache(
cfg.vlog = 2
# Set XSC policy.
cfg.xsc_check_policy = lambda step: (step in [7, 8])
cfg.enable_python_train_step_cache = enable_python_cache
cfg.cache_python_train_step = enable_python_cache

# Test training run.
trainer: SpmdTrainer = cfg.set(max_step=12).instantiate(parent=None)
Expand Down

0 comments on commit e63ae33

Please sign in to comment.