From e63ae33bbb9791b35d0eecb043a582f8ae492fcf Mon Sep 17 00:00:00 2001 From: Hanzhi Zhou Date: Wed, 15 Jan 2025 20:07:41 -0800 Subject: [PATCH] Fix --- axlearn/common/trainer.py | 9 +++++++-- axlearn/common/trainer_test.py | 2 +- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/axlearn/common/trainer.py b/axlearn/common/trainer.py index f8832fd9..fa4417d3 100644 --- a/axlearn/common/trainer.py +++ b/axlearn/common/trainer.py @@ -206,7 +206,7 @@ class Config(Module.Config): # pre-compilation checks (such as sharding check) that increases the step time for some # models. Note that this cache is always disabled at steps when xsc is enabled. # Defaults to None which is interpreted as True. - enable_python_train_step_cache: Optional[bool] = None + cache_python_train_step: Optional[bool] = None def __init__( self, @@ -279,6 +279,8 @@ def __init__( "xsc_check_policy was set for non-TPU XLA backend. Running without XSC." ) else: + if cfg.cache_python_train_step is True: + raise ValueError("cache_python_train_step cannot be True when xsc is enabled.") xsc_check_policy = maybe_instantiate(cfg.xsc_check_policy) self._xsc_check_policy: Optional[Callable[[int], bool]] = xsc_check_policy self._compiled_train_step: Optional[jax.stages.Compiled] = None @@ -973,6 +975,9 @@ def _get_compiled_train_step_fn( ) -> Callable[[TrainerState, NestedTensor], tuple[TrainerState, NestedTensor]]: """Build a fully compiled train step function. + Relies on the JAX pjit cache to avoid recompilation when with_xsc=True or + cache_python_train_step=False. + Args: train_state: A TrainerState instance. input_batch: A NestedTensor containing global arrays. @@ -985,7 +990,7 @@ def _get_compiled_train_step_fn( RuntimeError: If `with_xsc` is requested on heterogenous device kinds. """ if ( - not (self.config.enable_python_train_step_cache is False) + not (self.config.cache_python_train_step is False) and not with_xsc and self._compiled_train_step is not None ): diff --git a/axlearn/common/trainer_test.py b/axlearn/common/trainer_test.py index eebf0623..4d4c9ff3 100644 --- a/axlearn/common/trainer_test.py +++ b/axlearn/common/trainer_test.py @@ -488,7 +488,7 @@ def test_xsc_check_policy_and_compilation_cache( cfg.vlog = 2 # Set XSC policy. cfg.xsc_check_policy = lambda step: (step in [7, 8]) - cfg.enable_python_train_step_cache = enable_python_cache + cfg.cache_python_train_step = enable_python_cache # Test training run. trainer: SpmdTrainer = cfg.set(max_step=12).instantiate(parent=None)