# [ORPO] add nll_target for orpo nll loss (#503)
## Summary
Add an optional `nll_target` argument so the NLL term can be computed against its own labels rather than the preference target (needed for the ORPO NLL loss). A minimal usage sketch follows the checklist below.

- Hardware Type: <BLANK>
- [x] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [x] run `make test-convergence` to ensure convergence
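
A minimal usage sketch of the new argument (the import path and shapes are assumptions for illustration, not part of this diff):

```python
import torch

from liger_kernel.chunked_loss import LigerFusedLinearORPOLoss

B, T, H, V = 4, 8, 16, 32  # illustrative sizes; the first B // 2 rows are "chosen"
lm_head = torch.nn.Linear(H, V, bias=False)
hidden = torch.randn(B, T, H, requires_grad=True)
target = torch.randint(V, (B, T), dtype=torch.long)      # preference labels
nll_target = torch.randint(V, (B, T), dtype=torch.long)  # separate labels for the NLL term

orpo_loss = LigerFusedLinearORPOLoss(ignore_index=-100, beta=0.1)
loss, aux_outputs = orpo_loss(lm_head.weight, hidden, target, nll_target=nll_target)
loss.backward()
```

When `nll_target` is omitted, the NLL term falls back to `target`, so existing callers are unaffected.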
kashif authored Jan 9, 2025
1 parent 134a13e commit 9586a87
Showing 6 changed files with 82 additions and 33 deletions.
### benchmark/scripts/benchmark_orpo_loss.py (10 changes: 6 additions & 4 deletions)

```diff
@@ -45,12 +45,13 @@ def bench_memory_fused_linear_orpo_loss(
 
     _input = torch.randn(B, T, H, requires_grad=True, dtype=dtype, device=device)
     target = torch.randint(V, (B, T), dtype=torch.long, device=device)
+    nll_target = torch.randint(V, (B, T), dtype=torch.long, device=device)
 
     def fwd():
         if provider == "liger":
-            return liger_lm_head_orpo(_input, target)
+            return liger_lm_head_orpo(_input, target, nll_target)
         elif provider == "huggingface":
-            return torch_lm_head_orpo(_input, target)
+            return torch_lm_head_orpo(_input, target, nll_target)
 
     def full():
         y = fwd()
@@ -91,12 +92,13 @@ def bench_speed_fused_linear_orpo_loss(
 
     _input = torch.randn(B, T, H, requires_grad=True, dtype=dtype, device=device)
     target = torch.randint(V, (B, T), dtype=torch.long, device=device)
+    nll_target = torch.randint(V, (B, T), dtype=torch.long, device=device)
 
     def fwd():
         if provider == "liger":
-            return liger_lm_head_orpo(_input, target)
+            return liger_lm_head_orpo(_input, target, nll_target)
         elif provider == "huggingface":
-            return torch_lm_head_orpo(_input, target)
+            return torch_lm_head_orpo(_input, target, nll_target)
 
     if mode == "forward":
         ms_50, ms_20, ms_80 = triton.testing.do_bench(
```
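Both benchmark functions follow Triton's standard timing harness; a minimal sketch of the pattern with a stand-in closure (assumes a CUDA device; the matmul is illustrative, not the ORPO loss):

```python
import torch
import triton

def fwd():
    a = torch.randn(256, 256, device="cuda")
    return a @ a

# do_bench reruns the closure and reports the requested latency quantiles in ms.
ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd, quantiles=[0.5, 0.2, 0.8])
```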
### src/liger_kernel/chunked_loss/fused_linear_preference.py (52 changes: 40 additions & 12 deletions)

```diff
@@ -27,6 +27,7 @@ def forward(
         alpha=1.0,
         beta=0.1,
         compute_nll_loss=True,
+        nll_target=None,
         compiled=True,
         use_ref_model=False,
         ref_input=None,
@@ -58,6 +59,7 @@ def forward(
             alpha (float): Weight for the NLL loss.
             beta (float): Weight for the preference loss.
             compute_nll_loss (bool): Whether to compute NLL loss.
+            nll_target (torch.Tensor, optional): Target tensor for NLL loss. Shape: (batch_size, seq_len). If not provided, the target is used.
             compiled (bool): Whether to use torch compile for chunk accumulation.
             use_ref_model (bool): Whether to use a reference model for the alignment loss.
             ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size).
@@ -96,11 +98,12 @@ def forward(
             use_ref_model=use_ref_model,
             ref_weight=ref_weight,
             ref_bias=ref_bias,
+            full_nll_target=nll_target,
             average_log_prob=average_log_prob,
             **loss_kwargs,
         )
 
-        def fused_fwd_bwd(input_chunk, target_chunk, ref_input_chunk):
+        def fused_fwd_bwd(input_chunk, target_chunk, ref_input_chunk, chosen_nll_target_chunk):
             """
             Fused forward and backward pass for a chunk of input and target.
             """
@@ -111,13 +114,18 @@ def fused_fwd_bwd(input_chunk, target_chunk, ref_input_chunk):
                     target_chunk,
                     bias,
                     ref_input_chunk=ref_input_chunk,
+                    chosen_nll_target_chunk=chosen_nll_target_chunk,
                 )
             else:
                 return torch.func.grad_and_value(compute_loss, argnums=(0, 1), has_aux=True)(
-                    input_chunk, weight, target_chunk, ref_input_chunk=ref_input_chunk
+                    input_chunk,
+                    weight,
+                    target_chunk,
+                    ref_input_chunk=ref_input_chunk,
+                    chosen_nll_target_chunk=chosen_nll_target_chunk,
                 )
 
-        def accumulate_chunk(input_chunk, target_chunk, ref_input_chunk=None):
+        def accumulate_chunk(input_chunk, target_chunk, ref_input_chunk=None, chosen_nll_target_chunk=None):
             if bias is not None:
                 (
                     (chunk_grad_input, chunk_grad_weight, chunk_grad_bias),
@@ -132,7 +140,7 @@ def accumulate_chunk(input_chunk, target_chunk, ref_input_chunk=None):
                         *aux_outputs,
                     ),
                 ),
-                ) = fused_fwd_bwd(input_chunk, target_chunk, ref_input_chunk)
+                ) = fused_fwd_bwd(input_chunk, target_chunk, ref_input_chunk, chosen_nll_target_chunk)
                 grad_bias.add_(chunk_grad_bias)  # accumulate bias gradient
             else:
                 (
@@ -148,7 +156,7 @@ def accumulate_chunk(input_chunk, target_chunk, ref_input_chunk=None):
                         *aux_outputs,
                     ),
                 ),
-                ) = fused_fwd_bwd(input_chunk, target_chunk, ref_input_chunk)
+                ) = fused_fwd_bwd(input_chunk, target_chunk, ref_input_chunk, chosen_nll_target_chunk)
 
             # Accumulate gradients
             grad_weight.add_(chunk_grad_weight)
@@ -191,6 +199,9 @@ def accumulate_chunk(input_chunk, target_chunk, ref_input_chunk=None):
         _rejected_input_chunks = torch.chunk(_input[len_chosen:], chunks=chunks, dim=0)
         _rejected_target_chunks = torch.chunk(target[len_chosen:], chunks=chunks, dim=0)
 
+        if nll_target is not None:
+            _chosen_nll_target_chunks = torch.chunk(nll_target[:len_chosen], chunks=chunks, dim=0)
+
         if use_ref_model:
             _ref_chosen_input_chunks = torch.chunk(ref_input[:len_chosen], chunks=chunks, dim=0)
             _ref_rejected_input_chunks = torch.chunk(ref_input[len_chosen:], chunks=chunks, dim=0)
@@ -202,13 +213,15 @@ def accumulate_chunk(input_chunk, target_chunk, ref_input_chunk=None):
             rejected_target_chunk,
             ref_chosen_input_chunk,
             ref_rejected_input_chunk,
+            chosen_nll_target_chunk,
         ) in zip(
             _chosen_input_chunks,
             _rejected_input_chunks,
             _chosen_target_chunks,
             _rejected_target_chunks,
             (_ref_chosen_input_chunks if use_ref_model else [None] * len(_chosen_input_chunks)),
             (_ref_rejected_input_chunks if use_ref_model else [None] * len(_rejected_input_chunks)),
+            (_chosen_nll_target_chunks if nll_target is not None else [None] * len(_chosen_input_chunks)),
             strict=False,
         ):
@@ -222,9 +235,10 @@ def accumulate_chunk(input_chunk, target_chunk, ref_input_chunk=None):
                 torch._dynamo.mark_dynamic(target_chunk, 1)
                 torch._dynamo.mark_dynamic(target, 1)
                 torch._dynamo.mark_dynamic(ref_input_chunk, 1) if use_ref_model else None
+                torch._dynamo.mark_dynamic(chosen_nll_target_chunk, 1) if nll_target is not None else None
 
             # accumulate loss, gradients, and metrics
-            accumulate_chunk(input_chunk, target_chunk, ref_input_chunk)
+            accumulate_chunk(input_chunk, target_chunk, ref_input_chunk, chosen_nll_target_chunk)
 
         # combine grad_chosen_inputs and grad_rejected_inputs
         grad_inputs = grad_chosen_inputs + grad_rejected_inputs
@@ -258,7 +272,7 @@ def backward(ctx, *grad_output):
             grad_weight = grad_weight * grad_output[0][0]
             grad_bias = grad_bias * grad_output[0][0] if grad_bias is not None else None
 
-        return grad_input, grad_weight, None, grad_bias, None, None, None
+        return grad_input, grad_weight, None, grad_bias, None, None, None, None
 
     @staticmethod
     def chunk_forward(
@@ -268,6 +282,7 @@ def chunk_forward(
         bias=None,
         ignore_index=-100,
         compute_nll_loss=True,
+        chosen_nll_target_chunk=None,
         average_log_prob=True,
     ):
         len_chosen_chunk = target_chunk.shape[0] // 2
@@ -278,9 +293,12 @@ def chunk_forward(
 
         chosen_nll_loss = 0.0
         if compute_nll_loss:
+            nll_labels = (
+                chosen_nll_target_chunk if chosen_nll_target_chunk is not None else target_chunk[:len_chosen_chunk]
+            )
             chosen_nll_loss = F.nll_loss(
                 log_probs_chunk[:len_chosen_chunk].view(-1, log_probs_chunk.shape[-1]),
-                target_chunk[:len_chosen_chunk].view(-1),
+                nll_labels.view(-1),
                 reduction="sum",
                 ignore_index=ignore_index,
             )
@@ -324,6 +342,8 @@ def _compute_loss(
         ref_input_chunk=None,
         ref_weight=None,
         ref_bias=None,
+        full_nll_target=None,
+        chosen_nll_target_chunk=None,
         average_log_prob=True,
         **loss_kwargs,
     ):
@@ -343,6 +363,8 @@ def _compute_loss(
             use_ref_model (bool): Whether to use a reference model for the alignment loss.
             ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size).
             ref_bias (torch.Tensor, optional): Reference bias tensor. Shape: (vocab_size,).
+            full_nll_target (torch.Tensor, optional): Full target tensor for NLL loss. Shape: (batch_size, sequence_length).
+            chosen_nll_target_chunk (torch.Tensor, optional): Target tensor for NLL loss. Shape: (chunk_size, sequence_length). If not provided, the target_chunk is used.
             average_log_prob (bool): Whether to average log probabilities or sum them.
             loss_kwargs (dict): Additional arguments for the loss function.
         """
@@ -359,9 +381,14 @@ def _compute_loss(
             bias=bias,
             ignore_index=ignore_index,
             compute_nll_loss=compute_nll_loss,
+            chosen_nll_target_chunk=chosen_nll_target_chunk,
             average_log_prob=average_log_prob,
         )
-        chosen_nll_loss = chosen_nll_loss / (full_target[: full_target.shape[0] // 2] != ignore_index).sum()
+        if full_nll_target is not None:
+            chosen_nll_loss = (
+                chosen_nll_loss / (full_nll_target[: full_nll_target.shape[0] // 2] != ignore_index).sum()
+            )
+        else:
+            chosen_nll_loss = chosen_nll_loss / (full_target[: full_target.shape[0] // 2] != ignore_index).sum()
 
         chosen_logits_mean = chosen_logits.sum() / (full_target.shape[0] // 2 * input_chunk.shape[1] * weight.shape[0])
         rejected_logits_mean = rejected_logits.sum() / (
@@ -372,16 +399,17 @@ def _compute_loss(
                 (
                     ref_chosen_logps,
                     ref_rejected_logps,
-                    ref_chosen_logits,
-                    ref_rejected_logits,
-                    ref_chosen_nll_loss,
+                    _,
+                    _,
+                    _,
                 ) = LigerFusedLinearPreferenceBase.chunk_forward(
                     ref_input_chunk,
                     ref_weight,
                     target_chunk,
                     ref_bias,
                     ignore_index=ignore_index,
                     compute_nll_loss=False,  # We don't need NLL loss for the reference model
+                    chosen_nll_target_chunk=None,
                     average_log_prob=average_log_prob,
                 )
                 loss_kwargs["ref_chosen_logps"] = ref_chosen_logps
```
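For intuition, a simplified standalone sketch of the chunking added above: only the chosen half of `nll_target` is split, and `None` placeholders keep the `zip` aligned when it is absent (the helper name is ours, not the module's):

```python
import torch

def chunk_chosen_nll_targets(_input, target, nll_target, chunks):
    # The batch is laid out as [chosen; rejected] along dim 0, so only the
    # first half of nll_target feeds the chosen-side NLL loss.
    len_chosen = target.shape[0] // 2
    chosen_input_chunks = torch.chunk(_input[:len_chosen], chunks=chunks, dim=0)
    if nll_target is not None:
        chosen_nll_target_chunks = torch.chunk(nll_target[:len_chosen], chunks=chunks, dim=0)
    else:
        chosen_nll_target_chunks = [None] * len(chosen_input_chunks)
    return list(zip(chosen_input_chunks, chosen_nll_target_chunks))

pairs = chunk_chosen_nll_targets(
    torch.randn(8, 4, 2),
    torch.zeros(8, 4, dtype=torch.long),
    torch.ones(8, 4, dtype=torch.long),
    chunks=2,
)
```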
### src/liger_kernel/chunked_loss/orpo_loss.py (7 changes: 5 additions & 2 deletions)

```diff
@@ -52,6 +52,7 @@ def forward(
         ignore_index=-100,
         beta=0.1,
         compute_nll_loss=True,
+        nll_target=None,
         compiled=True,
     ):
         return LigerFusedLinearPreferenceBase.forward(
@@ -64,13 +65,14 @@ def forward(
             ignore_index=ignore_index,
             beta=beta,
             compute_nll_loss=compute_nll_loss,
+            nll_target=nll_target,
             compiled=compiled,
         )
 
     @staticmethod
     def backward(ctx, *grad_output):
         grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
-        return *grads, None, None, None, None
+        return *grads, None, None, None, None, None
 
 
 class LigerFusedLinearORPOLoss(torch.nn.Module):
@@ -96,7 +98,7 @@ def __init__(
         self.compute_nll_loss = compute_nll_loss
         self.compiled = compiled
 
-    def forward(self, lin_weight, _input, target, bias=None):
+    def forward(self, lin_weight, _input, target, bias=None, nll_target=None):
         return LigerFusedLinearORPOFunction.apply(
             _input,
             lin_weight,
@@ -105,5 +107,6 @@ def forward(self, lin_weight, _input, target, bias=None):
             self.ignore_index,
             self.beta,
             self.compute_nll_loss,
+            nll_target,
             self.compiled,
         )
```
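Because `torch.autograd.Function` requires one gradient (or `None`) per forward input, inserting `nll_target` between `compute_nll_loss` and `compiled` is why `backward` now returns one extra `None`. A hedged sketch of the positional ordering in a direct functional call (names and sizes are illustrative):

```python
import torch

from liger_kernel.chunked_loss.orpo_loss import LigerFusedLinearORPOFunction

B, T, H, V = 4, 8, 16, 32
_input = torch.randn(B, T, H, requires_grad=True)
lin_weight = torch.randn(V, H, requires_grad=True)
target = torch.randint(V, (B, T), dtype=torch.long)
nll_target = torch.randint(V, (B, T), dtype=torch.long)

loss, aux_outputs = LigerFusedLinearORPOFunction.apply(
    _input,
    lin_weight,
    target,
    None,        # bias
    -100,        # ignore_index
    0.1,         # beta
    True,        # compute_nll_loss
    nll_target,  # the new argument
    True,        # compiled
)
```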
### src/liger_kernel/transformers/trainer/orpo_trainer.py (20 changes: 16 additions & 4 deletions)

```diff
@@ -93,6 +93,13 @@ def concatenated_forward(
         if self.aux_loss_enabled:
             model_kwargs["output_router_logits"] = True
 
+        if self.is_encoder_decoder:
+            labels = concatenated_batch["concatenated_labels"].clone()
+        else:
+            labels = concatenated_batch["concatenated_input_ids"].clone()
+            attention_mask = concatenated_batch["concatenated_attention_mask"]
+            labels = torch.where(attention_mask == 1, labels, self.label_pad_token_id)
+
         if isinstance(model, FullyShardedDataParallel):
             outputs = _FSDPForwardRedirection()(
                 model,
@@ -114,15 +121,20 @@ def concatenated_forward(
 
         orpo_loss_fn = LigerFusedLinearORPOLoss(ignore_index=self.label_pad_token_id, beta=self.beta)
 
-        def orpo_partial(lm_head, last_hidden_state, concatenated_labels):
-            return orpo_loss_fn(lm_head.weight, last_hidden_state, concatenated_labels, lm_head.bias)
+        def orpo_partial(lm_head, last_hidden_state, concatenated_labels, nll_target):
+            return orpo_loss_fn(
+                lm_head.weight, last_hidden_state, concatenated_labels, lm_head.bias, nll_target=nll_target
+            )
 
         orpo_loss, aux_outputs = _FSDPForwardRedirection()(
             model,
             orpo_partial,
             model.lm_head,
-            outputs.last_hidden_state,
-            concatenated_batch["concatenated_labels"],
+            outputs.last_hidden_state[:, :-1] if not self.is_encoder_decoder else outputs.last_hidden_state,
+            concatenated_batch["concatenated_labels"][:, 1:]
+            if not self.is_encoder_decoder
+            else concatenated_batch["concatenated_labels"],
+            labels[:, 1:] if not self.is_encoder_decoder else labels,
         )
         # if aux_loss_enabled, add the aux_loss to the orpo_loss
         if self.aux_loss_enabled:
```
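The trainer change follows the usual decoder-only convention: the hidden state at position i predicts token i+1, so hidden states drop the last step and both label tensors drop the first, while encoder-decoder models skip the shift. A standalone sketch under assumed shapes (all names here are illustrative):

```python
import torch

B, T, H = 2, 6, 4  # illustrative sizes
last_hidden_state = torch.randn(B, T, H)
concatenated_labels = torch.randint(10, (B, T))
input_ids = torch.randint(10, (B, T))
attention_mask = torch.ones(B, T, dtype=torch.long)
label_pad_token_id = -100

# NLL labels come from the input ids, with padding positions masked out.
nll_labels = torch.where(attention_mask == 1, input_ids, label_pad_token_id)

# Causal shift for decoder-only models.
hidden = last_hidden_state[:, :-1]        # (B, T-1, H)
pref_labels = concatenated_labels[:, 1:]  # (B, T-1)
nll_target = nll_labels[:, 1:]            # (B, T-1)
```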
### test/chunked_loss/test_orpo_loss.py (18 changes: 10 additions & 8 deletions)

```diff
@@ -86,8 +86,8 @@ def __init__(
         self.lin = torch.nn.Linear(in_features=H, out_features=V, bias=bias, dtype=dtype)
         self.orpo_loss = HFORPOLoss(ignore_index=ignore_index, beta=beta).get_batch_loss_metrics
 
-    def forward(self, x, y):
-        return self.orpo_loss(self.lin.weight, x, y, self.lin.bias)
+    def forward(self, x, y, nll_target=None):
+        return self.orpo_loss(self.lin.weight, x, y, self.lin.bias, nll_target=nll_target)
 
 
 class LigerLMHeadORPO(torch.nn.Module):
@@ -104,8 +104,8 @@ def __init__(
         self.lin = torch.nn.Linear(in_features=H, out_features=V, bias=bias, dtype=dtype)
         self.orpo_loss = LigerFusedLinearORPOLoss(ignore_index=ignore_index, beta=beta)
 
-    def forward(self, x, y):
-        return self.orpo_loss(self.lin.weight, x, y, self.lin.bias)
+    def forward(self, x, y, nll_target=None):
+        return self.orpo_loss(self.lin.weight, x, y, self.lin.bias, nll_target=nll_target)
 
 
 @pytest.mark.parametrize(
@@ -164,13 +164,15 @@ def test_correctness(B, T, H, V, scalar, dtype, atol, rtol, bias, ignore_index, beta):
         device=device,
         dtype=torch.long,
     )
+    nll_target = torch.randint(0, V, (B, T), device=device, dtype=torch.long)
 
     # Assign some random number of elements as ignore_index
     num_elements_to_assign = torch.randint(1, B * T // 2, (1,)).item()
     indices_to_assign = torch.randperm(B * T)[:num_elements_to_assign]
     target.view(-1)[indices_to_assign] = ignore_index
 
-    loss1, aggregated_aux_outputs1 = torch_lm_head_orpo(input1, target)
-    loss2, aggregated_aux_outputs2 = liger_lm_head_orpo(input2, target)
+    loss1, aggregated_aux_outputs1 = torch_lm_head_orpo(input1, target, nll_target)
+    loss2, aggregated_aux_outputs2 = liger_lm_head_orpo(input2, target, nll_target)
 
     assert_verbose_allclose(loss1, loss2, atol=atol, rtol=rtol)
 
@@ -244,8 +246,8 @@ def test_correctness_functional(B, T, H, V, scalar, dtype, atol, rtol, bias):
     bias1 = _bias.detach().clone().requires_grad_(True) if bias else None
     bias2 = _bias.detach().clone().requires_grad_(True) if bias else None
 
-    loss1, aggregated_aux_outputs1 = LigerFusedLinearORPOFunction.apply(input1, weight1, target, bias1)
-    loss2, aggregated_aux_outputs2 = liger_fused_linear_orpo(input2, weight2, target, bias2)
+    loss1, _ = LigerFusedLinearORPOFunction.apply(input1, weight1, target, bias1)
+    loss2, _ = liger_fused_linear_orpo(input2, weight2, target, bias2)
 
     assert_verbose_allclose(loss1, loss2, atol=atol, rtol=rtol)
```
### test/utils.py (8 changes: 5 additions & 3 deletions)

```diff
@@ -406,8 +406,9 @@ def concatenated_forward(
         _input: torch.FloatTensor,
         weight: torch.FloatTensor,
         target: torch.LongTensor,
-        bias: torch.FloatTensor = None,
+        bias: torch.FloatTensor | None = None,
         average_log_prob: bool = True,
+        nll_target: torch.LongTensor | None = None,
     ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
         """Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together.
@@ -430,7 +431,7 @@ def cross_entropy_loss(logits, labels):
             loss = loss_fct(logits, labels)
             return loss
 
-        labels = target
+        labels = nll_target if nll_target is not None else target
         chosen_nll_loss = torch.tensor(0.0, device=all_logits.device)
         if self.compute_nll_loss:
             chosen_nll_loss = cross_entropy_loss(all_logits[:len_chosen], labels[:len_chosen])
@@ -465,10 +466,11 @@ def get_batch_loss_metrics(
         ref_weight: torch.FloatTensor = None,
         ref_bias: torch.FloatTensor = None,
         average_log_prob: bool = True,
+        nll_target: torch.LongTensor = None,
     ):
         """Compute the loss metrics for the given batch of inputs for train or test."""
 
-        forward_output = self.concatenated_forward(_input, weight, target, bias, average_log_prob)
+        forward_output = self.concatenated_forward(_input, weight, target, bias, average_log_prob, nll_target)
         (
             policy_chosen_logps,
             policy_rejected_logps,
```
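In the reference implementation used by the tests, the only behavioral change is which labels feed the chosen-side NLL term. A hedged standalone sketch of that fallback (the function name is ours):

```python
import torch
import torch.nn.functional as F

def chosen_nll_loss(chosen_logits, target, nll_target=None, ignore_index=-100):
    # Prefer the dedicated NLL labels; otherwise reuse the preference target.
    labels = nll_target if nll_target is not None else target
    return F.cross_entropy(
        chosen_logits.reshape(-1, chosen_logits.shape[-1]),
        labels.reshape(-1),
        reduction="sum",
        ignore_index=ignore_index,
    )
```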
