Skip to content

Commit e46aa89

Browse files
committed
fix indexing
Signed-off-by: ashors1 <ashors@nvidia.com>
1 parent af26c69 commit e46aa89

File tree

2 files changed

+4
-4
lines changed

2 files changed

+4
-4
lines changed

nemo_rl/algorithms/dpo.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -397,7 +397,7 @@ def dpo_train(
397397
)
398398

399399
is_last_step = total_steps >= master_config["dpo"]["max_num_steps"] or (
400-
current_epoch + 1 == max_num_epochs
400+
current_epoch == max_num_epochs
401401
and current_step == len(train_dataloader)
402402
)
403403

@@ -466,4 +466,4 @@ def dpo_train(
466466
return
467467

468468
current_epoch += 1
469-
current_step = 0 # Reset step counter for new epoch
469+
current_step = 1 # Reset step counter for new epoch

nemo_rl/algorithms/sft.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -410,7 +410,7 @@ def sft_train(
410410
train_results = policy.train(train_data, loss_fn)
411411

412412
is_last_step = total_steps >= master_config["sft"]["max_num_steps"] or (
413-
current_epoch + 1 == max_num_epochs
413+
current_epoch == max_num_epochs
414414
and current_step == len(train_dataloader)
415415
)
416416

@@ -487,4 +487,4 @@ def sft_train(
487487
return
488488

489489
current_epoch += 1
490-
current_step = 0 # Reset step counter for new epoch
490+
current_step = 1 # Reset step counter for new epoch

0 commit comments

Comments
 (0)