
Commit

Un-add files
loadams committed Nov 4, 2024
1 parent d378c25 commit fa87108
Showing 20 changed files with 266 additions and 594 deletions.
40 changes: 25 additions & 15 deletions applications/DeepSpeed-Chat/dschat/rlhf/ppo_trainer.py
@@ -13,7 +13,9 @@

 def print_all_ranks(tag, value, rank):
     world_size = torch.distributed.get_world_size()
-    all_tensor = torch.zeros(world_size, dtype=torch.float32, device=value.device)
+    all_tensor = torch.zeros(world_size,
+                             dtype=torch.float32,
+                             device=value.device)
     all_tensor[rank] = value
     torch.distributed.all_reduce(all_tensor, op=torch.distributed.ReduceOp.SUM)
     print_rank_0(f'{tag} {all_tensor}', rank)
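
A note for readers skimming the diff: print_all_ranks gathers one scalar from every rank without an explicit all_gather; each rank writes into its own slot of a zero tensor and a summed all_reduce fills in the rest. A minimal runnable sketch of the same pattern, assuming a single-process gloo group purely for demonstration (real runs get their process group from the deepspeed or torchrun launcher):

import os

import torch
import torch.distributed as dist

# Single-process "world" so the sketch runs standalone.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

rank = dist.get_rank()
world_size = dist.get_world_size()

# Each rank fills only its own slot; the SUM-reduce merges the slots so
# every rank ends up holding every rank's value.
value = torch.tensor(3.14)
all_tensor = torch.zeros(world_size, dtype=torch.float32)
all_tensor[rank] = value
dist.all_reduce(all_tensor, op=dist.ReduceOp.SUM)
print(f"values from all ranks: {all_tensor}")

dist.destroy_process_group()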
@@ -53,7 +55,8 @@ def __init__(self, rlhf_engine, args):
         self.end_of_conversation_token_id = self.tokenizer(
             args.end_of_conversation_token)['input_ids'][-1]
         self.z3_enabled = args.actor_zero_stage == 3
-        self.calculate_fp32_loss = (self.args.dtype == "bf16") and self.args.bf16_to_fp32_loss
+        self.calculate_fp32_loss = (self.args.dtype
+                                    == "bf16") and self.args.bf16_to_fp32_loss

         # In case the generated experience is not valid (too short), we use the last valid
         # generated experience. Alternatively, we can skip the step (on all workers).
@@ -89,15 +92,17 @@ def _generate_sequence(self, prompts, mask, step):
         if is_hpu() and self.args.enable_hpu_graphs:
             orig_actor_model_fwd_fn = self.actor_model.module.forward
             if self.first_generate:
-                self.actor_model.module.forward = thpu.wrap_in_hpu_graph_func(self.actor_model.module.forward)
+                self.actor_model.module.forward = thpu.wrap_in_hpu_graph_func(
+                    self.actor_model.module.forward)
                 self.first_generate = False
             else:
                 self.actor_model.module.forward = self.actor_model_hpu_graph_wrapped_fwd_fn
-            seq = self.actor_model.module.generate(prompts,
-                                                   attention_mask=mask,
-                                                   max_length=max_min_length,
-                                                   min_length=max_min_length,
-                                                   lazy_mode=True)
+            seq = self.actor_model.module.generate(
+                prompts,
+                attention_mask=mask,
+                max_length=max_min_length,
+                min_length=max_min_length,
+                lazy_mode=True)
             self.actor_model_hpu_graph_wrapped_fwd_fn = self.actor_model.module.forward
             self.actor_model.module.forward = orig_actor_model_fwd_fn
         else:
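
A note on the HPU-graph handling in the hunk above: the actor's forward is wrapped into a device graph exactly once, the wrapper is cached across generate calls, and the original forward is restored after each call. A backend-agnostic sketch of that wrap-once/cache/restore pattern, with a no-op wrap_in_graph_func standing in for thpu.wrap_in_hpu_graph_func and a plain Linear standing in for the actor model:

import torch

def wrap_in_graph_func(fn):
    # Stand-in for thpu.wrap_in_hpu_graph_func: a real implementation would
    # capture fn into a device graph and return a replayable callable.
    def wrapped(*args, **kwargs):
        return fn(*args, **kwargs)
    return wrapped

model = torch.nn.Linear(4, 4)
orig_fwd_fn = model.forward
first_generate = True
graph_wrapped_fwd_fn = None

for step in range(3):
    if first_generate:
        model.forward = wrap_in_graph_func(model.forward)  # wrap once
        first_generate = False
    else:
        model.forward = graph_wrapped_fwd_fn  # reuse the cached wrapper
    _ = model(torch.randn(1, 4))
    graph_wrapped_fwd_fn = model.forward  # remember the wrapper
    model.forward = orig_fwd_fn  # restore the original between calls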
@@ -117,7 +122,8 @@ def _generate_sequence(self, prompts, mask, step):
         ans = seq[:, prompt_length:]
         valid_ans_len = (ans != self.tokenizer.pad_token_id).sum(dim=-1)

-        if self.args.print_answers and (step % self.args.print_answers_interval == 0):
+        if self.args.print_answers and (step % self.args.print_answers_interval
+                                        == 0):
             print(
                 f"--- prompt --> step={step}, rank={torch.distributed.get_rank()}, {self.tokenizer.batch_decode(prompts, skip_special_tokens=True)}"
             )
@@ -129,17 +135,21 @@ def _generate_sequence(self, prompts, mask, step):
         for i in range(batch_size):
             if valid_ans_len[
                     i] <= 1:  # if the answer is shorter than 1 token, drop it
-                print(f'Dropping too short generated answer: {step=}: \n'
-                      f'prompts: {self.tokenizer.batch_decode(prompts, skip_special_tokens=False)}\n'
-                      f'answers: {self.tokenizer.batch_decode(ans, skip_special_tokens=False)}')
+                print(
+                    f'Dropping too short generated answer: {step=}: \n'
+                    f'prompts: {self.tokenizer.batch_decode(prompts, skip_special_tokens=False)}\n'
+                    f'answers: {self.tokenizer.batch_decode(ans, skip_special_tokens=False)}'
+                )
                 continue
             else:
                 out_seq.append(seq[i:i + 1])

         if not out_seq:
-            print(f'All generated results are too short for rank={self.args.local_rank} step={step}\n'
-                  f'-> prompts: {self.tokenizer.batch_decode(prompts, skip_special_tokens=False)}\n'
-                  f'-> answers: {self.tokenizer.batch_decode(ans, skip_special_tokens=False)}')
+            print(
+                f'All generated results are too short for rank={self.args.local_rank} step={step}\n'
+                f'-> prompts: {self.tokenizer.batch_decode(prompts, skip_special_tokens=False)}\n'
+                f'-> answers: {self.tokenizer.batch_decode(ans, skip_special_tokens=False)}'
+            )
             return None

         out_seq = torch.cat(out_seq, dim=0)  # concat output in the batch dim
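To make the drop logic in this hunk concrete: answers shorter than two non-pad tokens are filtered out and the survivors are re-concatenated along the batch dimension. A self-contained toy run of that filter (the token ids and pad_token_id=0 are invented for the example):

import torch

pad_token_id = 0
prompt_length = 3

# Two generated sequences: three prompt tokens, then the answer, then padding.
seq = torch.tensor([
    [5, 6, 7, 8, 9, 0, 0],  # 2 answer tokens -> kept
    [5, 6, 7, 4, 0, 0, 0],  # 1 answer token  -> dropped
])

ans = seq[:, prompt_length:]
valid_ans_len = (ans != pad_token_id).sum(dim=-1)

out_seq = [seq[i:i + 1] for i in range(seq.size(0)) if valid_ans_len[i] > 1]
out_seq = torch.cat(out_seq, dim=0) if out_seq else None  # batch-dim concat
print(out_seq)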
14 changes: 10 additions & 4 deletions applications/DeepSpeed-Chat/dschat/rlhf/rlhf_engine.py
@@ -108,7 +108,8 @@ def _init_actor(self, actor_model_name_or_path):
         # TODO SW-146776: remove this WA once SW-141762 is resolved
         if is_hpu():
             import habana_frameworks.torch.core as htcore
-            actor_model.to(dtype=torch.bfloat16, device=get_accelerator().device())
+            actor_model.to(dtype=torch.bfloat16,
+                           device=get_accelerator().device())

         # Optimizer
         if self.args.offload:
@@ -117,7 +118,9 @@ def _init_actor(self, actor_model_name_or_path):
             AdamOptimizer = torch.optim.AdamW
         else:
             AdamOptimizer = FusedAdam
-        print_rank_0(f'Using {AdamOptimizer.__name__} optimizer for actor model', self.args.global_rank)
+        print_rank_0(
+            f'Using {AdamOptimizer.__name__} optimizer for actor model',
+            self.args.global_rank)

         optim_params = get_optimizer_grouped_parameters(
             actor_model, self.args.actor_weight_decay,
@@ -249,7 +252,8 @@ def _init_critic(self, critic_model_name_or_path):

         # TODO SW-146776: remove this WA once SW-141762 is resolved
         if is_hpu():
-            critic_model.to(dtype=torch.bfloat16, device=get_accelerator().device())
+            critic_model.to(dtype=torch.bfloat16,
+                            device=get_accelerator().device())

         # Optimizer
         # TODO SW-147425: change the file to use HPEX optimizer instead of AdamW on hpu
@@ -259,7 +263,9 @@ def _init_critic(self, critic_model_name_or_path):
             AdamOptimizer = torch.optim.AdamW
         else:
             AdamOptimizer = FusedAdam
-        print_rank_0(f'Using {AdamOptimizer.__name__} optimizer for critic model', self.args.global_rank)
+        print_rank_0(
+            f'Using {AdamOptimizer.__name__} optimizer for critic model',
+            self.args.global_rank)

         optim_params = get_optimizer_grouped_parameters(
             critic_model, self.args.critic_weight_decay,
6 changes: 4 additions & 2 deletions applications/DeepSpeed-Chat/dschat/utils/data/data_utils.py
@@ -298,8 +298,10 @@ def create_prompt_dataset(local_rank,
     eval_fname = f"{output_path}/evaldata_{fname}.pt"

     cache_found = os.path.isfile(train_fname) and os.path.isfile(eval_fname)
-    device = torch.device(get_accelerator().device_name(torch.distributed.get_rank()))
-    buf_create_cache = get_accelerator().ByteTensor([not cache_found], device=device)
+    device = torch.device(get_accelerator().device_name(
+        torch.distributed.get_rank()))
+    buf_create_cache = get_accelerator().ByteTensor([not cache_found],
+                                                    device=device)
     torch.distributed.all_reduce(buf_create_cache)

     if local_rank <= 0 and (buf_create_cache.item() != 0 or reload):
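Context for the reflowed lines: create_prompt_dataset lets every rank report whether its cached dataset files exist, and a summed all_reduce tells all ranks whether anyone needs the cache rebuilt (rank 0 then rebuilds while the others wait). A single-process sketch of that agreement logic, with the collective replaced by a plain sum:

import torch

# One flag per rank: 1 means "my cache files were not found".
cache_found_per_rank = [True, False, True]  # e.g. rank 1 misses its cache
flags = torch.tensor([int(not found) for found in cache_found_per_rank],
                     dtype=torch.uint8)

# After all_reduce(SUM) every rank would hold the same total; any nonzero
# total means at least one rank saw a cache miss.
buf_create_cache = flags.sum()
if buf_create_cache.item() != 0:
    print("rebuild the dataset cache on rank 0; other ranks wait")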
55 changes: 30 additions & 25 deletions applications/DeepSpeed-Chat/dschat/utils/model/model_utils.py
@@ -84,35 +84,40 @@ def causal_lm_forward(

 def configure_dropout(model_config, dropout):
     if dropout is not None:
-        for key in ('dropout', 'attention_dropout', 'hidden_dropout', 'activation_dropout'):
+        for key in ('dropout', 'attention_dropout', 'hidden_dropout',
+                    'activation_dropout'):
             if hasattr(model_config, key):
                 print(f"Setting model_config.{key} to {dropout}")
                 setattr(model_config, key, dropout)


 def causal_lm_model_to_fp32_loss(model):
     """ Convert CausalLM model to calculate loss in fp32 """
-    def causal_lm_forward(input_ids=None,
-                          past_key_values=None,
-                          attention_mask=None,
-                          head_mask=None,
-                          inputs_embeds=None,
-                          labels=None,
-                          use_cache=None,
-                          output_attentions=None,
-                          output_hidden_states=None,
-                          return_dict=None,
-                          **deprecated_arguments, ):
-        output = model.__original_forward__(input_ids=input_ids,
-                                            past_key_values=past_key_values,
-                                            attention_mask=attention_mask,
-                                            head_mask=head_mask,
-                                            inputs_embeds=inputs_embeds,
-                                            labels=None,
-                                            use_cache=use_cache,
-                                            output_attentions=output_attentions,
-                                            output_hidden_states=output_hidden_states,
-                                            return_dict=return_dict)
+
+    def causal_lm_forward(
+        input_ids=None,
+        past_key_values=None,
+        attention_mask=None,
+        head_mask=None,
+        inputs_embeds=None,
+        labels=None,
+        use_cache=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None,
+        **deprecated_arguments,
+    ):
+        output = model.__original_forward__(
+            input_ids=input_ids,
+            past_key_values=past_key_values,
+            attention_mask=attention_mask,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            labels=None,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict)

         return_dict = isinstance(output, dict)
         lm_logits = output.logits if return_dict else output[0]
@@ -127,12 +132,12 @@ def causal_lm_forward(input_ids=None,
         # Flatten the tokens
         loss_fct = torch.nn.CrossEntropyLoss()
         loss = loss_fct(
-            shift_logits.view(batch_size * seq_length, vocab_size), shift_labels.view(batch_size * seq_length)
-        )
+            shift_logits.view(batch_size * seq_length, vocab_size),
+            shift_labels.view(batch_size * seq_length))

         if not return_dict:
             # re-pack output with fp32 loss
-            return ((loss,) + output) if loss is not None else output
+            return ((loss, ) + output) if loss is not None else output

         output.loss = loss
         return output
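For orientation on this file's hunks: causal_lm_model_to_fp32_loss wraps a model's forward so that bf16 runs still compute the next-token cross-entropy in fp32. A toy recreation of the wrapped loss computation, with random logits and labels and shapes invented for the example:

import torch

batch, seq_len, vocab = 2, 5, 11
lm_logits = torch.randn(batch, seq_len, vocab, dtype=torch.bfloat16)
labels = torch.randint(0, vocab, (batch, seq_len))

# Shift so that tokens < n predict token n, upcasting logits to fp32 first.
shift_logits = lm_logits[..., :-1, :].float().contiguous()
shift_labels = labels[..., 1:].contiguous()
batch_size, seq_length, vocab_size = shift_logits.shape

loss_fct = torch.nn.CrossEntropyLoss()
loss = loss_fct(
    shift_logits.view(batch_size * seq_length, vocab_size),
    shift_labels.view(batch_size * seq_length))
print(loss.dtype)  # torch.float32 even though the logits were bf16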
49 changes: 35 additions & 14 deletions applications/DeepSpeed-Chat/dschat/utils/model/reward_model.py
@@ -10,7 +10,12 @@
 ## https://github.com/CarperAI/trlx/blob/main/examples/summarize_rlhf/reward_model/reward_model.py
 class RewardModel(nn.Module):

-    def __init__(self, base_model, tokenizer, num_padding_at_beginning=0, loss_to_fp32=False, opt_loss_calc=False):
+    def __init__(self,
+                 base_model,
+                 tokenizer,
+                 num_padding_at_beginning=0,
+                 loss_to_fp32=False,
+                 opt_loss_calc=False):
         super().__init__()
         self.config = base_model.config
         self.num_padding_at_beginning = num_padding_at_beginning
@@ -112,35 +117,50 @@ def get_last_before_padding(paddings, num_begin_padding):

                 # united_unpadding_mask will what are the unite between the unpadded elements
                 # will indicate 1's where we have non padded tokens, in either of the inputs
-                united_unpadding_mask = torch.logical_not(torch.logical_and(chosen_padding_mask, rejected_padding_mask))
+                united_unpadding_mask = torch.logical_not(
+                    torch.logical_and(chosen_padding_mask,
+                                      rejected_padding_mask))

                 # get a mask of all the different tokens
                 divergence_mask = (chosen_id != rejected_id)
                 divergence_mask = divergence_mask.cumsum(0).bool()

                 # loss mask indicates the elements which should be taken into consideration after sigmoid calc
                 # from the first divergence, till the last non padded token
-                loss_mask = torch.logical_and(divergence_mask, united_unpadding_mask)
-                loss_mask = torch.where(divergence_mask.sum().bool(), loss_mask, self.fallback_mask)
+                loss_mask = torch.logical_and(divergence_mask,
+                                              united_unpadding_mask)
+                loss_mask = torch.where(divergence_mask.sum().bool(),
+                                        loss_mask, self.fallback_mask)

                 # calc logsigmoid on all the input and mask the not interesting ones
                 if self.loss_to_fp32:
                     chosen_reward = chosen_reward.float()
                     rejected_reward = rejected_reward.float()
-                logsigmoid = torch.nn.functional.logsigmoid(chosen_reward.float() - rejected_reward.float()) * loss_mask
+                logsigmoid = torch.nn.functional.logsigmoid(
+                    chosen_reward.float() -
+                    rejected_reward.float()) * loss_mask
                 #average according to the interesting number of elements
                 num_elements_in_loss = loss_mask.sum().float()
                 loss += -(logsigmoid.sum() / num_elements_in_loss)

                 # log the c_ind / r_ind in chosen_mean_scores / rejected_mean_scores
-                c_ind_mask = get_last_before_padding(chosen_padding_mask, self.num_padding_at_beginning)
-                c_ind_mask = torch.where(chosen_padding_mask.sum() > self.num_padding_at_beginning, c_ind_mask, self.fallback_mask)
-                chosen_mean_score = (c_ind_mask.float() * chosen_reward.float()).sum()
+                c_ind_mask = get_last_before_padding(
+                    chosen_padding_mask, self.num_padding_at_beginning)
+                c_ind_mask = torch.where(
+                    chosen_padding_mask.sum() > self.num_padding_at_beginning,
+                    c_ind_mask, self.fallback_mask)
+                chosen_mean_score = (c_ind_mask.float() *
+                                     chosen_reward.float()).sum()
                 chosen_mean_scores.append(chosen_mean_score)

-                r_ind_mask = get_last_before_padding(rejected_padding_mask, self.num_padding_at_beginning)
-                r_ind_mask = torch.where(rejected_padding_mask.sum() > self.num_padding_at_beginning, r_ind_mask, self.fallback_mask)
-                rejected_mean_score = (r_ind_mask.float() * rejected_reward.float()).sum()
+                r_ind_mask = get_last_before_padding(
+                    rejected_padding_mask, self.num_padding_at_beginning)
+                r_ind_mask = torch.where(
+                    rejected_padding_mask.sum() >
+                    self.num_padding_at_beginning, r_ind_mask,
+                    self.fallback_mask)
+                rejected_mean_score = (r_ind_mask.float() *
+                                       rejected_reward.float()).sum()
                 rejected_mean_scores.append(rejected_mean_score)
             else:
                 c_inds = (chosen_id == self.PAD_ID).nonzero()
@@ -156,7 +176,8 @@ def get_last_before_padding(paddings, num_begin_padding):
                 # Check if there is any padding otherwise take length of sequence
                 r_inds = (rejected_id == self.PAD_ID).nonzero()
                 r_ind = r_inds[self.num_padding_at_beginning].item(
-                ) if len(r_inds) > self.num_padding_at_beginning else seq_len
+                ) if len(
+                    r_inds) > self.num_padding_at_beginning else seq_len
                 end_ind = max(c_ind, r_ind)
                 divergence_ind = check_divergence[0]
                 assert divergence_ind > 0
@@ -165,8 +186,8 @@ def get_last_before_padding(paddings, num_begin_padding):
                 if self.loss_to_fp32:
                     c_truncated_reward = c_truncated_reward.float()
                     r_truncated_reward = r_truncated_reward.float()
-                loss += -torch.nn.functional.logsigmoid(c_truncated_reward -
-                                                        r_truncated_reward).mean()
+                loss += -torch.nn.functional.logsigmoid(
+                    c_truncated_reward - r_truncated_reward).mean()

                 chosen_mean_scores.append(
                     chosen_reward[c_ind - 1])  #use the end score for reference
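Both the masked (opt_loss_calc) branch and the truncated-span branch in this file optimize the same objective: a pairwise ranking loss that pushes the chosen response's reward above the rejected one's. A toy version with made-up per-token rewards, including the optional loss_to_fp32 upcast:

import torch

# Per-token rewards over the compared span (values invented for the example).
c_truncated_reward = torch.tensor([0.9, 1.2, 0.7], dtype=torch.bfloat16)
r_truncated_reward = torch.tensor([0.4, 0.1, 0.6], dtype=torch.bfloat16)

# Optional upcast, mirroring the loss_to_fp32 flag.
c = c_truncated_reward.float()
r = r_truncated_reward.float()

# -log(sigmoid(chosen - rejected)) per token, then averaged: the loss is
# small wherever the chosen reward already exceeds the rejected one.
loss = -torch.nn.functional.logsigmoid(c - r).mean()
print(loss.item())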
15 changes: 11 additions & 4 deletions applications/DeepSpeed-Chat/dschat/utils/utils.py
@@ -92,7 +92,9 @@ def get_tokenizer(model_name_or_path, fast_tokenizer=True):
     return tokenizer


-def load_hf_tokenizer(model_name_or_path, fast_tokenizer=True, add_special_tokens=None):
+def load_hf_tokenizer(model_name_or_path,
+                      fast_tokenizer=True,
+                      add_special_tokens=None):
     if os.path.exists(model_name_or_path):
         # Locally tokenizer loading has some issue, so we need to force download
         model_json = os.path.join(model_name_or_path, "config.json")
@@ -109,7 +111,8 @@ def load_hf_tokenizer(model_name_or_path, fast_tokenizer=True, add_special_tokens=None):
     if add_special_tokens is not None:
         add_special_tokens = [add_special_tokens] if isinstance(add_special_tokens, str) \
             else add_special_tokens
-        tokenizer.add_special_tokens({'additional_special_tokens': add_special_tokens})
+        tokenizer.add_special_tokens(
+            {'additional_special_tokens': add_special_tokens})

     return tokenizer

@@ -208,7 +211,10 @@ def get_optimizer_grouped_parameters(
     model,
     weight_decay,
     lora_lr=5e-4,
-    no_decay_name_list=["bias", "layer_norm.weight", "layernorm.weight", "norm.weight", "ln_f.weight"],
+    no_decay_name_list=[
+        "bias", "layer_norm.weight", "layernorm.weight", "norm.weight",
+        "ln_f.weight"
+    ],
     lora_name_list=["lora_right_weight", "lora_left_weight"],
 ):
     optimizer_grouped_parameters = [
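
The no_decay_name_list reformatted above drives a standard trick: biases and normalization weights are exempted from weight decay. A simplified sketch of the grouping (the real helper also splits out LoRA parameters onto their own learning rate, omitted here):

import torch

class TinyBlock(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(4, 4)
        self.layer_norm = torch.nn.LayerNorm(4)

model = TinyBlock()
no_decay_name_list = ["bias", "layer_norm.weight", "layernorm.weight",
                      "norm.weight", "ln_f.weight"]

decay, no_decay = [], []
for name, param in model.named_parameters():
    # Names such as "proj.bias" or "layer_norm.weight" match the exclusion
    # list; plain weights like "proj.weight" do not.
    if any(nd in name.lower() for nd in no_decay_name_list):
        no_decay.append(param)
    else:
        decay.append(param)

optimizer = torch.optim.AdamW([
    {"params": decay, "weight_decay": 0.01},
    {"params": no_decay, "weight_decay": 0.0},
])
print([len(g["params"]) for g in optimizer.param_groups])  # [1, 3]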
@@ -313,7 +319,8 @@ def print_loss(epoch, step, steps_per_print, gas, loss, loss_sum, rank):
     opt_step = step / gas
     avg_loss = loss_sum / gas
     print_rank_0(
-        f"[{datetime.now()}] epoch: {epoch} | step: {opt_step} | avg_loss: {avg_loss}", rank)
+        f"[{datetime.now()}] epoch: {epoch} | step: {opt_step} | avg_loss: {avg_loss}",
+        rank)
     if step > 0 and step % gas == 0:
         loss_sum.zero_()


Five of the changed files were deleted outright (contents not shown); diffs for the remaining changed files are not shown here.
