-
Notifications
You must be signed in to change notification settings - Fork 553
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[RFC] Proposal to Update PPO Test to Add LR Scheduler #2423
Open
Seoley
wants to merge
4
commits into
pytorch:main
Choose a base branch
from
Seoley:lr_scheduler_to_PPO
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+82
−1
Open
Changes from all commits
Commits
Show all changes
4 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change | ||
---|---|---|---|---|
|
@@ -25,6 +25,7 @@ | |||
from torchtune.recipe_interfaces import FTRecipeInterface | ||||
from torchtune.rlhf import PPOStats, Trajectory | ||||
from torchtune.training import disable_dropout, DummyProfiler, PROFILER_KEY | ||||
from torchtune.training.lr_schedulers import get_lr | ||||
from tqdm import tqdm | ||||
|
||||
log = utils.get_logger("DEBUG") | ||||
|
@@ -257,6 +258,18 @@ def setup(self, cfg: DictConfig) -> None: | |||
* (self.batch_size // self._ppo_batch_size) | ||||
) | ||||
|
||||
self._steps_per_epoch = ( | ||||
len(self._dataloader) // self._gradient_accumulation_steps | ||||
) | ||||
self.global_step = self._epochs_run * self._steps_per_epoch | ||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||||
|
||||
# Setup lr scheduler | ||||
self._lr_scheduler = self._setup_lr_scheduler( | ||||
cfg_lr_scheduler=cfg.get("lr_scheduler", None), | ||||
num_training_steps=self._total_epochs * self._steps_per_epoch, | ||||
last_epoch=self.global_step - 1, | ||||
) | ||||
|
||||
# Set up profiler, returns DummyProfiler (nullcontext object with no-op `step` method) | ||||
# if cfg is missing profiler key or if `cfg.profiler.enabled = False` | ||||
self._profiler = self._setup_profiler(cfg.get(PROFILER_KEY, None)) | ||||
|
@@ -328,6 +341,53 @@ def _setup_profiler( | |||
|
||||
return profiler | ||||
|
||||
def _setup_lr_scheduler( | ||||
self, | ||||
cfg_lr_scheduler: Optional[DictConfig], | ||||
num_training_steps: int, | ||||
last_epoch: int, | ||||
) -> Optional[Optimizer]: | ||||
""" | ||||
Set up the learning rate scheduler based on the provided configuration. | ||||
It handles both standard optimization and optimizer-in-backward cases, and supports | ||||
schedulers from both torchtune.modules and torch.optim. | ||||
|
||||
Args: | ||||
cfg_lr_scheduler (Optional[DictConfig]): The learning rate scheduler configuration. | ||||
num_training_steps (int): The total number of training steps. | ||||
last_epoch (int): The index of the last epoch. | ||||
|
||||
Returns: | ||||
lr_scheduler (Optional[Optimizer]): The learning rate scheduler. | ||||
""" | ||||
if cfg_lr_scheduler is None: | ||||
log.info( | ||||
"No learning rate scheduler configured. Using constant learning rate." | ||||
) | ||||
return None | ||||
|
||||
if self._optimizer_in_bwd: | ||||
# Use the first optimizer from the wrapper to represent the learning rate | ||||
optimizer = next(iter(self._optim_ckpt_wrapper.optim_map.values())) | ||||
else: | ||||
# Standard case: use the single optimizer | ||||
optimizer = self._optimizer | ||||
|
||||
# Instantiate the learning rate scheduler | ||||
lr_scheduler = config.instantiate( | ||||
cfg_lr_scheduler, | ||||
optimizer, | ||||
num_training_steps=num_training_steps, | ||||
last_epoch=last_epoch, | ||||
) | ||||
|
||||
if self._optimizer_in_bwd: | ||||
# Modify the scheduler for optimizer_in_bwd case | ||||
self._optim_ckpt_wrapper.set_lr_scheduler(lr_scheduler) | ||||
|
||||
log.info("Learning rate scheduler is initialized.") | ||||
return lr_scheduler | ||||
|
||||
def _setup_training_hyperparameters(self, cfg) -> None: | ||||
""" | ||||
Sets up the training hyperparameters for the recipe. This includes the GAE hyperparameters, | ||||
|
@@ -1000,9 +1060,21 @@ def train(self) -> None: | |||
self._optimizer.step() | ||||
self._optimizer.zero_grad(set_to_none=True) | ||||
|
||||
# Need to fix `lr_scheduler.step()` before `optimizer.step()` warning | ||||
if self._lr_scheduler is not None: | ||||
self._lr_scheduler.step() | ||||
self.global_step += 1 | ||||
|
||||
ppo_time = time.perf_counter() - t0_ppo | ||||
|
||||
current_lr = get_lr( | ||||
( | ||||
self._optimizer | ||||
if not self._optimizer_in_bwd | ||||
else self._optim_ckpt_wrapper | ||||
), | ||||
) | ||||
|
||||
# step 5. profit | ||||
self._steps_run += 1 | ||||
if self._steps_run % self._log_every_n_steps == 0: | ||||
|
@@ -1013,6 +1085,7 @@ def train(self) -> None: | |||
kl_rewards, | ||||
num_tokens / traj_time, | ||||
num_tokens / ppo_time, | ||||
current_lr, | ||||
) | ||||
self.cleanup_after_step( | ||||
trajectory, ppo_stats, advantages, returns, kl, kl_rewards | ||||
|
@@ -1139,6 +1212,7 @@ def log_metrics( | |||
kl_rewards: torch.Tensor, | ||||
tokens_per_second_trajectory: torch.Tensor, | ||||
tokens_per_second_loss: torch.Tensor, | ||||
lr: float, | ||||
) -> None: | ||||
""" | ||||
Log metrics and statistics for the current step to the metric logger. | ||||
|
@@ -1149,6 +1223,7 @@ def log_metrics( | |||
"rlhf_reward": trajectory.scores.mean() + kl_rewards.sum(1).mean(), | ||||
"kl": kl.sum(1).mean(), | ||||
"kl_reward": kl_rewards.sum(1).mean(), | ||||
"lr": lr, | ||||
"loss": ppo_stats.loss.mean(), | ||||
"policy_loss": ppo_stats.policy_loss.mean(), | ||||
"value_loss": ppo_stats.value_loss.mean(), | ||||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we should name this
lr_steps
and the correct value would be the total number of optimizer steps being taken, which should beself._total_steps * self._ppo_epochs * (self.batch_size // self._ppo_batch_size)
, right?