diff --git a/applications/DeepSpeed-Chat/dschat/utils/model/reward_model.py b/applications/DeepSpeed-Chat/dschat/utils/model/reward_model.py index 60d063b18..4f29d0dd8 100644 --- a/applications/DeepSpeed-Chat/dschat/utils/model/reward_model.py +++ b/applications/DeepSpeed-Chat/dschat/utils/model/reward_model.py @@ -98,8 +98,8 @@ def forward(self, else: # Check if there is any padding otherwise take length of sequence r_inds = (rejected_id == self.PAD_ID).nonzero() - r_ind = r_inds[self.num_padding_at_beginning].item( - ) if len(r_inds) > self.num_padding_at_beginning else seq_len + r_ind = r_inds[self.num_padding_at_beginning].item() if len( + r_inds) > self.num_padding_at_beginning else seq_len end_ind = max(c_ind, r_ind) divergence_ind = check_divergence[0] assert divergence_ind > 0 diff --git a/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/main.py b/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/main.py index f3db70e05..1378dc4e6 100644 --- a/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/main.py +++ b/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/main.py @@ -246,10 +246,9 @@ def parse_args(): '--offload_reference_model', action='store_true', help='Enable ZeRO Offload techniques for reference model') - parser.add_argument( - '--offload_reward_model', - action='store_true', - help='Enable ZeRO Offload techniques for reward model') + parser.add_argument('--offload_reward_model', + action='store_true', + help='Enable ZeRO Offload techniques for reward model') parser.add_argument( '--actor_zero_stage', type=int,