From 2babdd3c260d47362dc09873f6a1e2cc16af8feb Mon Sep 17 00:00:00 2001 From: Konstantinos Fertakis Date: Tue, 15 Oct 2024 10:13:04 +0100 Subject: [PATCH] more formatting fixes --- .../DeepSpeed-Chat/dschat/utils/model/reward_model.py | 4 ++-- .../DeepSpeed-Chat/training/step3_rlhf_finetuning/main.py | 7 +++---- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/applications/DeepSpeed-Chat/dschat/utils/model/reward_model.py b/applications/DeepSpeed-Chat/dschat/utils/model/reward_model.py index 60d063b18..4f29d0dd8 100644 --- a/applications/DeepSpeed-Chat/dschat/utils/model/reward_model.py +++ b/applications/DeepSpeed-Chat/dschat/utils/model/reward_model.py @@ -98,8 +98,8 @@ def forward(self, else: # Check if there is any padding otherwise take length of sequence r_inds = (rejected_id == self.PAD_ID).nonzero() - r_ind = r_inds[self.num_padding_at_beginning].item( - ) if len(r_inds) > self.num_padding_at_beginning else seq_len + r_ind = r_inds[self.num_padding_at_beginning].item() if len( + r_inds) > self.num_padding_at_beginning else seq_len end_ind = max(c_ind, r_ind) divergence_ind = check_divergence[0] assert divergence_ind > 0 diff --git a/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/main.py b/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/main.py index f3db70e05..1378dc4e6 100644 --- a/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/main.py +++ b/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/main.py @@ -246,10 +246,9 @@ def parse_args(): '--offload_reference_model', action='store_true', help='Enable ZeRO Offload techniques for reference model') - parser.add_argument( - '--offload_reward_model', - action='store_true', - help='Enable ZeRO Offload techniques for reward model') + parser.add_argument('--offload_reward_model', + action='store_true', + help='Enable ZeRO Offload techniques for reward model') parser.add_argument( '--actor_zero_stage', type=int,