File tree 2 files changed +4
-4
lines changed
2 files changed +4
-4
lines changed Original file line number Diff line number Diff line change @@ -397,7 +397,7 @@ def dpo_train(
397
397
)
398
398
399
399
is_last_step = total_steps >= master_config ["dpo" ]["max_num_steps" ] or (
400
- current_epoch + 1 == max_num_epochs
400
+ current_epoch == max_num_epochs
401
401
and current_step == len (train_dataloader )
402
402
)
403
403
@@ -466,4 +466,4 @@ def dpo_train(
466
466
return
467
467
468
468
current_epoch += 1
469
- current_step = 0 # Reset step counter for new epoch
469
+ current_step = 1 # Reset step counter for new epoch
Original file line number Diff line number Diff line change @@ -410,7 +410,7 @@ def sft_train(
410
410
train_results = policy .train (train_data , loss_fn )
411
411
412
412
is_last_step = total_steps >= master_config ["sft" ]["max_num_steps" ] or (
413
- current_epoch + 1 == max_num_epochs
413
+ current_epoch == max_num_epochs
414
414
and current_step == len (train_dataloader )
415
415
)
416
416
@@ -487,4 +487,4 @@ def sft_train(
487
487
return
488
488
489
489
current_epoch += 1
490
- current_step = 0 # Reset step counter for new epoch
490
+ current_step = 1 # Reset step counter for new epoch
You can’t perform that action at this time.
0 commit comments