[RFC] Proposal to Update PPO Test to Add LR Scheduler #2423

Open · wants to merge 4 commits into main
6 changes: 5 additions & 1 deletion recipes/configs/mistral/7B_full_ppo_low_memory.yaml
@@ -30,7 +30,7 @@ output_dir: /tmp/torchtune/mistral_7B/full_ppo_low_memory # /tmp may be deleted
 tokenizer:
   _component_: torchtune.models.mistral.mistral_tokenizer
   path: /tmp/Mistral-7B-Instruct-v0.2/tokenizer.model
-  max_seq_len: null
+  max_seq_len: 512

 # Dataset
 dataset:
@@ -205,3 +205,7 @@ profiler:
   warmup_steps: 3
   active_steps: 3
   num_cycles: 1
+
+lr_scheduler:
+  _component_: torchtune.training.lr_schedulers.get_cosine_schedule_with_warmup
+  num_warmup_steps: 100
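
For orientation, the component referenced above follows the usual warmup-then-cosine shape: the learning rate ramps up linearly for num_warmup_steps optimizer steps and then decays along a cosine curve over the remaining training steps. A minimal, self-contained sketch of that multiplier (function name and numbers are illustrative; the exact torchtune implementation may differ in details such as how num_cycles is handled):

import math

import torch
from torch.optim.lr_scheduler import LambdaLR


def cosine_with_warmup(optimizer, num_warmup_steps, num_training_steps,
                       num_cycles=0.5, last_epoch=-1):
    # Linear warmup to the base LR, then cosine decay toward zero.
    def lr_lambda(step):
        if step < num_warmup_steps:
            return step / max(1, num_warmup_steps)
        progress = (step - num_warmup_steps) / max(1, num_training_steps - num_warmup_steps)
        return max(0.0, 0.5 * (1.0 + math.cos(math.pi * num_cycles * 2.0 * progress)))

    return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch)


# Illustrative usage with a dummy parameter and the warmup value from this config.
opt = torch.optim.AdamW([torch.nn.Parameter(torch.zeros(1))], lr=2e-5)
sched = cosine_with_warmup(opt, num_warmup_steps=100, num_training_steps=1000)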
75 changes: 75 additions & 0 deletions recipes/ppo_full_finetune_single_device.py
@@ -25,6 +25,7 @@
 from torchtune.recipe_interfaces import FTRecipeInterface
 from torchtune.rlhf import PPOStats, Trajectory
 from torchtune.training import disable_dropout, DummyProfiler, PROFILER_KEY
+from torchtune.training.lr_schedulers import get_lr
 from tqdm import tqdm

 log = utils.get_logger("DEBUG")
@@ -257,6 +258,18 @@ def setup(self, cfg: DictConfig) -> None:
             * (self.batch_size // self._ppo_batch_size)
         )

+        self._steps_per_epoch = (
+            len(self._dataloader) // self._gradient_accumulation_steps
+        )
+        self.global_step = self._epochs_run * self._steps_per_epoch
+
+        # Setup lr scheduler
+        self._lr_scheduler = self._setup_lr_scheduler(
+            cfg_lr_scheduler=cfg.get("lr_scheduler", None),
+            num_training_steps=self._total_epochs * self._steps_per_epoch,
+            last_epoch=self.global_step - 1,
+        )
+
         # Set up profiler, returns DummyProfiler (nullcontext object with no-op `step` method)
         # if cfg is missing profiler key or if `cfg.profiler.enabled = False`
         self._profiler = self._setup_profiler(cfg.get(PROFILER_KEY, None))

Collaborator, on self._steps_per_epoch: I think we should name this lr_steps and the correct value would be the total number of optimizer steps being taken, which should be self._total_steps * self._ppo_epochs * (self.batch_size // self._ppo_batch_size), right?

Collaborator, on self.global_step: global_step is already defined here - how come you're re-defining it here?
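
To make the reviewer's step-counting concrete, here is a small back-of-the-envelope sketch; the variable names mirror the recipe attributes, but the numbers are made up purely for illustration:

# Illustrative values only; they do not come from the PR or its configs.
total_steps = 1000      # self._total_steps: number of PPO generation steps
ppo_epochs = 2          # self._ppo_epochs: optimisation epochs per batch of trajectories
batch_size = 64         # self.batch_size
ppo_batch_size = 16     # self._ppo_batch_size

# Reviewer's suggestion: schedule over every optimizer step actually taken.
lr_steps = total_steps * ppo_epochs * (batch_size // ppo_batch_size)
print(lr_steps)  # 1000 * 2 * 4 = 8000

By contrast, len(self._dataloader) // self._gradient_accumulation_steps counts dataloader batches per epoch, which is what the PR currently uses and what the reviewer is questioning.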
@@ -328,6 +341,53 @@ def _setup_profiler(

         return profiler

+    def _setup_lr_scheduler(
+        self,
+        cfg_lr_scheduler: Optional[DictConfig],
+        num_training_steps: int,
+        last_epoch: int,
+    ) -> Optional[Optimizer]:
+        """
+        Set up the learning rate scheduler based on the provided configuration.
+        It handles both standard optimization and optimizer-in-backward cases, and supports
+        schedulers from both torchtune.modules and torch.optim.
+
+        Args:
+            cfg_lr_scheduler (Optional[DictConfig]): The learning rate scheduler configuration.
+            num_training_steps (int): The total number of training steps.
+            last_epoch (int): The index of the last epoch.
+
+        Returns:
+            lr_scheduler (Optional[Optimizer]): The learning rate scheduler.
+        """
+        if cfg_lr_scheduler is None:
+            log.info(
+                "No learning rate scheduler configured. Using constant learning rate."
+            )
+            return None
+
+        if self._optimizer_in_bwd:
+            # Use the first optimizer from the wrapper to represent the learning rate
+            optimizer = next(iter(self._optim_ckpt_wrapper.optim_map.values()))
+        else:
+            # Standard case: use the single optimizer
+            optimizer = self._optimizer
+
+        # Instantiate the learning rate scheduler
+        lr_scheduler = config.instantiate(
+            cfg_lr_scheduler,
+            optimizer,
+            num_training_steps=num_training_steps,
+            last_epoch=last_epoch,
+        )
+
+        if self._optimizer_in_bwd:
+            # Modify the scheduler for optimizer_in_bwd case
+            self._optim_ckpt_wrapper.set_lr_scheduler(lr_scheduler)
+
+        log.info("Learning rate scheduler is initialized.")
+        return lr_scheduler
+
     def _setup_training_hyperparameters(self, cfg) -> None:
         """
         Sets up the training hyperparameters for the recipe. This includes the GAE hyperparameters,
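
To connect this helper back to the YAML change above: config.instantiate resolves the _component_ path and forwards the remaining config keys as keyword arguments, so with this PR's config the call is roughly equivalent to the following (a hedged paraphrase of the wiring, not the exact torchtune internals):

from torchtune.training.lr_schedulers import get_cosine_schedule_with_warmup

lr_scheduler = get_cosine_schedule_with_warmup(
    optimizer,                              # single optimizer, or the first one from the in-backward wrapper
    num_warmup_steps=100,                   # from the lr_scheduler section of the YAML config
    num_training_steps=num_training_steps,  # computed by the recipe in setup()
    last_epoch=last_epoch,                  # self.global_step - 1, i.e. -1 on a fresh run
)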
@@ -1000,9 +1060,21 @@ def train(self) -> None:
             self._optimizer.step()
             self._optimizer.zero_grad(set_to_none=True)

+            # Need to fix `lr_scheduler.step()` before `optimizer.step()` warning
+            if self._lr_scheduler is not None:
+                self._lr_scheduler.step()
+            self.global_step += 1

             ppo_time = time.perf_counter() - t0_ppo

+            current_lr = get_lr(
+                (
+                    self._optimizer
+                    if not self._optimizer_in_bwd
+                    else self._optim_ckpt_wrapper
+                ),
+            )
+
             # step 5. profit
             self._steps_run += 1
             if self._steps_run % self._log_every_n_steps == 0:
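
The "Need to fix" comment above refers to PyTorch's ordering check: an LRScheduler emits a UserWarning when its step() is detected before the optimizer has stepped, which presumably can happen here in the optimizer-in-backward path where there is no explicit optimizer.step() call. A minimal, recipe-independent repro of that warning:

import warnings

import torch

params = [torch.nn.Parameter(torch.zeros(1))]
opt = torch.optim.SGD(params, lr=0.1)
sched = torch.optim.lr_scheduler.LambdaLR(opt, lambda step: 1.0)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    sched.step()  # scheduler stepped before any optimizer.step()
    # PyTorch reports the lr_scheduler.step()-before-optimizer.step() ordering warning.
    print([str(w.message) for w in caught])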
@@ -1013,6 +1085,7 @@
                 kl_rewards,
                 num_tokens / traj_time,
                 num_tokens / ppo_time,
+                current_lr,
             )
             self.cleanup_after_step(
                 trajectory, ppo_stats, advantages, returns, kl, kl_rewards
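
For context on the get_lr helper imported above: it abstracts over the two optimizer setups so the recipe can log a single learning-rate value. Conceptually it does something like the sketch below (a simplification, not the actual torchtune source; it assumes every param group, and every wrapped optimizer in the in-backward case, shares one LR):

import torch


def get_lr_sketch(opt_or_wrapper) -> float:
    # Plain optimizer: read the LR off its (assumed identical) param groups.
    if isinstance(opt_or_wrapper, torch.optim.Optimizer):
        return opt_or_wrapper.param_groups[0]["lr"]
    # Optimizer-in-backward wrapper: take the first wrapped optimizer's LR,
    # mirroring how _setup_lr_scheduler picks a representative optimizer.
    first_opt = next(iter(opt_or_wrapper.optim_map.values()))
    return first_opt.param_groups[0]["lr"]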
@@ -1139,6 +1212,7 @@ def log_metrics(
         kl_rewards: torch.Tensor,
         tokens_per_second_trajectory: torch.Tensor,
         tokens_per_second_loss: torch.Tensor,
+        lr: float,
     ) -> None:
         """
         Log metrics and statistics for the current step to the metric logger.
@@ -1149,6 +1223,7 @@
             "rlhf_reward": trajectory.scores.mean() + kl_rewards.sum(1).mean(),
             "kl": kl.sum(1).mean(),
             "kl_reward": kl_rewards.sum(1).mean(),
+            "lr": lr,
             "loss": ppo_stats.loss.mean(),
             "policy_loss": ppo_stats.policy_loss.mean(),
             "value_loss": ppo_stats.value_loss.mean(),
2 changes: 2 additions & 0 deletions tests/recipes/test_ppo_full_finetune_single_device.py
@@ -55,6 +55,8 @@ def _get_test_config_overrides(self):
             "seed=9",
             "optimizer=torch.optim.AdamW",
             "optimizer.lr=2e-5",
+            "lr_scheduler.num_warmup_steps=0",
+            "lr_scheduler.num_cycles=0",
             "log_every_n_steps=1",
             "compile=False",
         ] + dummy_text_completion_alpaca_dataset_config()
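
Assuming the schedule follows the usual HF-style cosine formula (as in the sketch near the top), these overrides make the test's schedule effectively constant: with num_warmup_steps=0 there is no ramp, and with num_cycles=0 the cosine term is cos(0) = 1 at every step, so the multiplier stays at 1.0 and the LR remains at the configured optimizer.lr=2e-5. A quick check of that claim:

import math

# 0.5 * (1 + cos(pi * num_cycles * 2 * progress)) with num_cycles = 0 is always 1.0.
for progress in (0.0, 0.25, 0.5, 1.0):
    assert 0.5 * (1.0 + math.cos(math.pi * 0 * 2.0 * progress)) == 1.0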