
Commit 9542030

feat: add CE-U loss

Author: Bo Yang
1 parent: 4d81a08

5 files changed: +98, -2 lines

README.md

Lines changed: 2 additions & 2 deletions
@@ -21,7 +21,7 @@

 ## 📖 Overview

-We provide efficient and streamlined implementations of the TOFU, MUSE unlearning benchmarks while supporting 6 unlearning methods, 3+ datasets, 6+ evaluation metrics, and 7+ LLMs. Each of these can be easily extended to incorporate more variants.
+We provide efficient and streamlined implementations of the TOFU, MUSE unlearning benchmarks while supporting 7 unlearning methods, 3+ datasets, 6+ evaluation metrics, and 7+ LLMs. Each of these can be easily extended to incorporate more variants.

 We invite the LLM unlearning community to collaborate by adding new benchmarks, unlearning methods, datasets and evaluation metrics here to expand OpenUnlearning's features, gain feedback from wider usage and drive progress in the field.

@@ -35,7 +35,7 @@ We provide several variants for each of the components in the unlearning pipeline

 | **Component** | **Available Options** |
 |------------------------|----------------------|
 | **Benchmarks** | [TOFU](https://arxiv.org/abs/2401.06121), [MUSE](https://muse-bench.github.io/) |
-| **Unlearning Methods** | GradAscent, GradDiff, NPO, SimNPO, DPO, RMU |
+| **Unlearning Methods** | GradAscent, GradDiff, NPO, SimNPO, DPO, RMU, CE-U |
 | **Evaluation Metrics** | Verbatim Probability, Verbatim ROUGE, QA-ROUGE, MIA Attacks, TruthRatio, Model Utility |
 | **Datasets** | MUSE-News (BBC), MUSE-Books (Harry Potter), TOFU (different splits) |
 | **Model Families** | TOFU: LLaMA-3.2, LLaMA-3.1, LLaMA-2; MUSE: LLaMA-2, ICLM; Additional: Phi-3.5, Phi-1.5, Gemma |

configs/trainer/CEU.yaml

Lines changed: 6 additions & 0 deletions
@@ -0,0 +1,6 @@
+defaults:
+  - finetune
+
+handler: CEU
+method_args:
+  ignore_first_n_answer_tokens: 1
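
For orientation (editorial note, not part of the commit): the config can be parsed on its own to inspect the CE-U fields. The `defaults: - finetune` entry is resolved by the project's config tooling and is not reproduced in this sketch.

```python
# Standalone sanity check of the new config (illustrative; assumes PyYAML is
# installed and the path is relative to the repository root).
import yaml

with open("configs/trainer/CEU.yaml") as f:
    cfg = yaml.safe_load(f)

print(cfg["handler"])                                      # "CEU"
print(cfg["method_args"]["ignore_first_n_answer_tokens"])  # 1
```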

src/trainer/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -10,6 +10,7 @@
 from trainer.unlearn.dpo import DPO
 from trainer.unlearn.simnpo import SimNPO
 from trainer.unlearn.rmu import RMU
+from trainer.unlearn.ceu import CEU

 TRAINER_REGISTRY: Dict[str, Any] = {}

@@ -81,3 +82,4 @@ def load_trainer(
 _register_trainer(DPO)
 _register_trainer(SimNPO)
 _register_trainer(RMU)
+_register_trainer(CEU)
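
A minimal sketch, assuming the registry keys trainers by class name (consistent with `handler: CEU` in the config above), of how `_register_trainer` and `TRAINER_REGISTRY` are typically wired; the actual `load_trainer` implementation is unchanged by this commit and lies outside the hunks shown.

```python
# Hypothetical illustration of the registry pattern; not this repository's code.
from typing import Any, Dict

TRAINER_REGISTRY: Dict[str, Any] = {}


def _register_trainer(trainer_cls) -> None:
    # Key the class by its name so a config's `handler` string can select it.
    TRAINER_REGISTRY[trainer_cls.__name__] = trainer_cls


class CEU:  # stand-in for trainer.unlearn.ceu.CEU
    pass


_register_trainer(CEU)
assert TRAINER_REGISTRY["CEU"] is CEU
# method_args from configs/trainer/CEU.yaml would then be forwarded to the
# selected class's constructor by load_trainer (an assumption about the wiring).
```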

src/trainer/unlearn/ceu.py

Lines changed: 11 additions & 0 deletions
@@ -0,0 +1,11 @@
+from trainer.unlearn.base import UnlearnTrainer
+from trainer.utils import compute_batch_ceu
+class CEU(UnlearnTrainer):
+    def __init__(self, ignore_first_n_answer_tokens=1, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.ignore_first_n_answer_tokens = ignore_first_n_answer_tokens
+
+    def compute_loss(self, model, inputs, return_outputs=False):
+        forget_inputs = inputs["forget"]
+        loss, outputs = compute_batch_ceu(model, forget_inputs, ignore_first_n_answer_tokens=self.ignore_first_n_answer_tokens)
+        return (loss, outputs) if return_outputs else loss
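
The trainer only reads the `"forget"` sub-batch and delegates to `compute_batch_ceu` (added in `src/trainer/utils.py` below), so the collated batch needs labels with `-100` on prompt/padding positions plus whatever the model's forward pass expects. A rough, hand-written illustration of that shape, not the repository's collator:

```python
# Illustrative forget batch (assumed structure; the real collator lives elsewhere
# in the repo). Labels use -100 so only answer tokens contribute to the CE-U loss.
import torch

forget_inputs = {
    "input_ids": torch.tensor([[101, 2054, 2003, 1996, 3437, 102]]),
    "attention_mask": torch.ones(1, 6, dtype=torch.long),
    "labels": torch.tensor([[-100, -100, -100, 1996, 3437, 102]]),  # prompt masked
}
inputs = {"forget": forget_inputs}
# loss = trainer.compute_loss(model, inputs)  # trainer: an instantiated CEU
```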

src/trainer/utils.py

Lines changed: 77 additions & 0 deletions
@@ -55,3 +55,80 @@ def compute_dpo_loss(model, ref_model, win_inputs=None, lose_inputs=None, beta=1

     loss = -2 / beta * F.logsigmoid(beta * (win_log_ratio - lose_log_ratio)).mean()
     return loss, (win_outputs, lose_outputs)
+
+
+def cross_entropy_unlearning_loss(
+    logits: torch.Tensor,
+    labels: torch.Tensor,
+    ignore_index: int = -100,
+) -> torch.Tensor:
+    """
+    Implementation of Cross Entropy Unlearning Loss (CE-U).
+
+    This function creates a modified target distribution by setting the logit corresponding to the true label to negative infinity, effectively forcing the model to assign zero probability to the correct answer. The loss then minimizes the KL divergence between this target distribution and the model's output.
+
+    Args:
+        logits: Model output logits with shape [batch_size, sequence_length, vocabulary_size]
+        labels: Ground truth token indices with shape [batch_size, sequence_length]
+        ignore_index: Token indices to ignore in the loss calculation (typically padding)
+
+    Returns:
+        A scalar tensor representing the mean unlearning loss across valid positions
+    """
+    batch_size, sequence_length, vocabulary_size = logits.shape
+    # Extract valid logits and labels based on ignore_index.
+    if ignore_index is not None:
+        # Shape: [batch_size, sequence_length], boolean mask
+        valid_mask = labels != ignore_index
+        # Shape: [num_valid_positions, vocabulary_size]
+        valid_logits = logits[valid_mask]
+        # Shape: [num_valid_positions]
+        valid_labels = labels[valid_mask]
+    else:
+        # Shape: [batch_size*sequence_length, vocabulary_size]
+        valid_logits = logits.view(-1, vocabulary_size)
+        # Shape: [batch_size*sequence_length]
+        valid_labels = labels.view(-1)
+
+    # Create a copy of valid_logits to generate the target distribution
+    # Shape: [num_valid_positions, vocabulary_size]
+    valid_target_logits = valid_logits.detach().clone()
+
+    # Suppress the logits corresponding to the true token by setting them to -inf.
+    # This ensures that the probability for the true token is effectively zero after softmax.
+    valid_target_logits.scatter_(
+        dim=-1,
+        index=valid_labels.unsqueeze(-1),  # Shape: [num_valid_positions, 1]
+        value=float("-inf"),
+    )  # Result shape: [num_valid_positions, vocabulary_size]
+
+    # Apply softmax to generate the target probability distribution
+    # Shape: [num_valid_positions, vocabulary_size]
+    valid_target_probabilities = F.softmax(valid_target_logits, dim=-1)
+
+    # Compute the cross entropy loss between input logits and target probabilities
+    # The loss is averaged over the valid positions and returns a scalar tensor
+    return F.cross_entropy(
+        input=valid_logits,
+        target=valid_target_probabilities,
+    )
+
+
+def compute_batch_ceu(model, inputs, ignore_first_n_answer_tokens=1):
+    outputs = model(**inputs)
+    logits = outputs.logits
+    labels = inputs["labels"]
+
+    # Implement the trick to ignore the first n answer tokens mentioned in the footnote in the Training Settings section of arXiv:2503.01224
+    valid_mask = labels != -100
+    update_mask = (
+        valid_mask.cumsum(dim=-1) <= ignore_first_n_answer_tokens
+    ) & valid_mask
+    labels_without_first_n_answer_tokens = labels.masked_fill(update_mask, -100)
+
+    shifted_labels = labels_without_first_n_answer_tokens[..., 1:].contiguous()
+    shifted_logits = logits[..., :-1, :].contiguous()
+    loss = cross_entropy_unlearning_loss(
+        shifted_logits, shifted_labels, ignore_index=-100
+    )
+    return loss, outputs
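
In equation form (editorial note, consistent with the code above): for a valid position with true token $y$ and logits $z \in \mathbb{R}^{|V|}$, the detached target distribution $q$ and the per-position loss are

$$
q_v \;=\; \begin{cases} \dfrac{\exp(z_v)}{\sum_{u \neq y} \exp(z_u)} & v \neq y \\[4pt] 0 & v = y \end{cases}
\qquad\qquad
\ell_{\text{CE-U}} \;=\; -\sum_{v \in V} q_v \,\log \operatorname{softmax}(z)_v .
$$

Because $q$ is built from detached logits, gradients flow only through $\log \operatorname{softmax}(z)$; the gradient at the true-token logit equals its current probability $p_y > 0$, so descent pushes that probability toward zero while redistributing mass according to the model's own ranking of the other tokens. A toy check of the new function (assumes the repository's `src/` directory is on `PYTHONPATH` so `trainer.utils` is importable):

```python
import torch
from trainer.utils import cross_entropy_unlearning_loss

torch.manual_seed(0)
logits = torch.randn(1, 4, 8, requires_grad=True)  # [batch, seq, vocab]
labels = torch.tensor([[-100, 3, 5, 1]])           # -100 marks prompt/pad positions

loss = cross_entropy_unlearning_loss(logits, labels, ignore_index=-100)
loss.backward()

# Gradients at the true-token logits are positive, i.e. descent lowers them.
print(loss.item())
print(logits.grad[0, 1, 3], logits.grad[0, 2, 5], logits.grad[0, 3, 1])
```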
