From 192abbcf90d30ce103e153f5fe3d34bbee1a4623 Mon Sep 17 00:00:00 2001 From: yuuee-www <969761826@qq.com> Date: Mon, 24 Mar 2025 21:34:21 +0800 Subject: [PATCH 1/4] add gru component --- configs/trainer/GRU.yaml | 7 + src/trainer/__init__.py | 3 + src/trainer/unlearn/gru.py | 368 +++++++++++++++++++++++++++++++++++++ 3 files changed, 378 insertions(+) create mode 100644 configs/trainer/GRU.yaml create mode 100644 src/trainer/unlearn/gru.py diff --git a/configs/trainer/GRU.yaml b/configs/trainer/GRU.yaml new file mode 100644 index 0000000..326eb87 --- /dev/null +++ b/configs/trainer/GRU.yaml @@ -0,0 +1,7 @@ +defaults: + - finetune + +handler: GRU +method_args: + gamma_gru: 0.8 + \ No newline at end of file diff --git a/src/trainer/__init__.py b/src/trainer/__init__.py index 7e195fa..00fe6b1 100644 --- a/src/trainer/__init__.py +++ b/src/trainer/__init__.py @@ -10,6 +10,8 @@ from trainer.unlearn.dpo import DPO from trainer.unlearn.simnpo import SimNPO from trainer.unlearn.rmu import RMU +from trainer.unlearn.gru import GRU + TRAINER_REGISTRY: Dict[str, Any] = {} @@ -81,3 +83,4 @@ def load_trainer( _register_trainer(DPO) _register_trainer(SimNPO) _register_trainer(RMU) +_register_trainer(GRU) diff --git a/src/trainer/unlearn/gru.py b/src/trainer/unlearn/gru.py new file mode 100644 index 0000000..e7fddd5 --- /dev/null +++ b/src/trainer/unlearn/gru.py @@ -0,0 +1,368 @@ +from transformers.utils import is_sagemaker_mp_enabled +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union +from transformers.training_args import OptimizerNames, ParallelMode, TrainingArguments +from collections.abc import Mapping +from pathlib import Path +import logging + +logger = logging.getLogger(__name__) + +import torch +from torch import nn +from copy import deepcopy +from packaging import version +from trainer.base import FinetuneTrainer + +from transformers.trainer_pt_utils import ( + nested_detach, +) + + +from transformers.utils import ( + is_sagemaker_mp_enabled, +) + +from accelerate.utils import ( + is_deepspeed_available, +) + +if is_sagemaker_mp_enabled(): + from smdistributed.modelparallel import __version__ as SMP_VERSION + + IS_SAGEMAKER_MP_POST_1_10 = version.parse(SMP_VERSION) >= version.parse("1.10") + + from transformers.trainer_pt_utils import ( + smp_forward_only, + smp_nested_concat, + ) +else: + IS_SAGEMAKER_MP_POST_1_10 = False + +if is_deepspeed_available(): + import deepspeed + + + + +from trainer.unlearn.base import UnlearnTrainer +import numpy as np + + +def print_metrics(step, metrics): + """ + Print current training metrics in a formatted way. + + Args: + epoch (int): Current epoch or iteration number. + metrics (dict): A dictionary containing metric names as keys and their current values. + """ + # Prepare the formatted string + metrics_string = ', '.join([f"{key}: {value:.4f}" for key, value in metrics.items()]) + print(f"Step {step}: {metrics_string}") + +class GRU(UnlearnTrainer): + def __init__(self, gamma_gru=0.8, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.gamma_gru = gamma_gru + self.gradient_accumulation_steps = kwargs["args"].gradient_accumulation_steps + + self.dotp_retain = None + self.flattened_gradient = 0.0 + self.flattened_memory = 0.0 + self.flattened_memory_old = 0.0 + self.flattened_memory_accumulation = 0.0 + self.structure_map = None + + self.steps = 0 + + self.gradient_accum = {} + + self.memory_grad = {} + + def orthogonal_component(self, g, g1): + + g1g1 = self.compute_total_gradient_dot_product(g1, self.structure_map, g1, self.structure_map) + gg1 = self.dotp_retain + print(gg1/g1g1) + projection = gg1/g1g1* g1 + orthogonal = g - projection + + return orthogonal + + def store_grads(self, model, loss=None, typ=None): + """ + Accumulates gradients of specified layers, preserving their original shapes. + Optionally, adjusts which layers are trainable just before computing gradients. + + Args: + model (torch.nn.Module): The model from which to store gradients. + loss (torch.Tensor, optional): The loss tensor to perform backward operation. If provided, will compute gradients. + + Returns: + None: Modifies internal tensors to store accumulated gradients. + """ + + # Perform backward pass if a loss tensor is provided + if loss: + loss = loss / self.gradient_accumulation_steps + loss.backward() + + # Loop through parameters and accumulate gradients + for name, param in model.named_parameters(): + if param.requires_grad: + if param.grad is None: + param.grad = torch.zeros_like(param) + + # Choose the correct dictionary based on 'typ' + if typ == "objective": + target_dict = self.gradient_accum + elif typ == "retain": + target_dict = self.memory_grad + else: + raise ValueError("Invalid type specified for gradient storage") + + # Initialize the dictionary key if it doesn't exist + if name not in target_dict: + target_dict[name] = torch.zeros_like(param.grad, device=param.grad.device) # Initialize on the same device + + # Accumulate the gradients + target_dict[name] += param.grad.detach() + + if loss: + model.zero_grad() + + def flatten_and_store_grads(self): + """ + Flattens accumulated gradients from different gradient dictionaries, moves them to CPU, + and stores them along with a structure map for each type of gradient. + """ + + # Helper function to flatten gradients, move to CPU, and record their original structure + def flatten_to_cpu_and_record_structure(gradient_dict): + flattened_grads = [] + structure_map = [] + for name, grad in gradient_dict.items(): + if grad is not None: + grad_flat = grad.view(-1) + flattened_grads.append(grad_flat) + structure_map.append((name, grad.shape)) + + if flattened_grads: + return torch.cat(flattened_grads).to('cpu'), structure_map + else: + return torch.tensor([], dtype=torch.float32).to('cpu'), [] + + + self.flattened_gradient, self.structure_map = flatten_to_cpu_and_record_structure(self.gradient_accum) + + self.flattened_memory_accumulation, _ = flatten_to_cpu_and_record_structure(self.memory_grad) + + def compute_total_gradient_dot_product(self, flattened_grads1, structure_map1, flattened_grads2, structure_map2): + """ + Computes the total dot product between gradients from two sets of flattened gradients and their respective structure maps. + + Args: + flattened_grads1 (torch.Tensor): The first flattened gradient tensor. + structure_map1 (list): A list of tuples containing parameter names and their corresponding shapes for the first set of gradients. + flattened_grads2 (torch.Tensor): The second flattened gradient tensor. + structure_map2 (list): A list of tuples containing parameter names and their corresponding shapes for the second set of gradients. + + Returns: + float: The total dot product summed across all matching layers. + """ + #assert len(structure_map1) == len(structure_map2), "Both gradient structures must contain the same number of elements." + + total_dot_product = 0.0 + index = 0 + + # Ensure both gradient tensors are on the same device + flattened_grads1 = flattened_grads1.to('cuda') + flattened_grads2 = flattened_grads2.to('cuda') + + # for ((name1, shape1), (name2, shape2)) in zip(structure_map1, structure_map2): + # assert name1 == name2 and shape1 == shape2, f"Gradient mismatch: {name1} vs {name2} or {shape1} vs {shape2}" + + for ((name1, shape1), (name2, shape2)) in zip(structure_map1, structure_map2): + assert shape1 == shape2, f"Gradient mismatch: {name1} vs {name2} or {shape1} vs {shape2}" + + size = np.prod(shape1) # Total number of elements in this layer's gradient + grad_slice1 = flattened_grads1[index:index + size].view(shape1) + grad_slice2 = flattened_grads2[index:index + size].view(shape2) + + # Compute the dot product of the two gradient slices + dot_product = (grad_slice1 * grad_slice2).sum() + total_dot_product += dot_product.item() + + index += size + + return total_dot_product + + def restore_gradients_from_flat(self, model): + """ + Restores gradients to the model's parameters directly from a flattened gradient tensor. + + Args: + model (torch.nn.Module): The model to which the gradients will be restored. + flattened_grads (torch.Tensor): The flattened gradient tensor. + structure_map (list): A list of tuples containing parameter names and their corresponding shapes. + """ + + index = 0 # Index to track position in the flattened gradient tensor + + for name, shape in self.structure_map: + size = np.prod(shape) # Total number of elements in this gradient + if size == 0: # Skip layers with no parameters + continue + + # Extract the relevant slice from the flattened gradient tensor + grad_slice = self.flattened_gradient[index:index + size].view(shape) + + # Find the corresponding parameter in the model + param = next((p for n, p in model.named_parameters() if n == name), None) + if param.requires_grad: + # Check if the shape of the extracted gradient matches the parameter's shape + if grad_slice.shape != param.shape: + raise ValueError(f"Gradient shape mismatch for {name}: expected {param.shape}, got {grad_slice.shape}") + + # Set the parameter's gradient to the extracted slice + param.grad = grad_slice.to(param.device) + + index += size # Update index to the start of the next gradient slice + + if index != self.flattened_gradient.numel(): + raise ValueError("Total number of gradient elements does not match the length of the flattened gradient tensor.") + + def pipeline(self): + if self.dotp_retain<0: + print("dotp_retain:",self.dotp_retain) + self.flattened_gradient = self.orthogonal_component(self.flattened_gradient, self.flattened_memory) + torch.cuda.empty_cache() + + def compute_retain_loss(self, model, retain_inputs): + retain_outputs = model(**retain_inputs) + retain_loss = 0.0 + retain_loss += retain_outputs.loss + return retain_loss + + def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor: + """ + Perform a training step on a batch of inputs. + + Subclass and override to inject custom behavior. + + Args: + model (`nn.Module`): + The model to train. + inputs (`Dict[str, Union[torch.Tensor, Any]]`): + The inputs and targets of the model. + + The dictionary will be unpacked before being fed to the model. Most models expect the targets under the + argument `labels`. Check your model's documentation for all accepted arguments. + + Return: + `torch.Tensor`: The tensor with training loss on this batch. + """ + model.train() + if hasattr(self.optimizer, "train") and callable(self.optimizer.train): + self.optimizer.train() + + inputs = self._prepare_inputs(inputs) + if is_sagemaker_mp_enabled(): + loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps) + return loss_mb.reduce_mean().detach().to(self.args.device) + + with self.compute_loss_context_manager(): + loss = self.compute_loss(model, inputs) + + del inputs + if ( + self.args.torch_empty_cache_steps is not None + and self.state.global_step % self.args.torch_empty_cache_steps == 0 + ): + if is_torch_xpu_available(): + torch.xpu.empty_cache() + elif is_torch_mlu_available(): + torch.mlu.empty_cache() + elif is_torch_musa_available(): + torch.musa.empty_cache() + elif is_torch_npu_available(): + torch.npu.empty_cache() + elif is_torch_mps_available(min_version="2.0"): + torch.mps.empty_cache() + else: + torch.cuda.empty_cache() + + kwargs = {} + + # For LOMO optimizers you need to explicitly use the learnign rate + if self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]: + kwargs["learning_rate"] = self._get_learning_rate() + + if self.args.n_gpu > 1: + loss = loss.mean() # mean() to average on multi-gpu parallel training + + if self.use_apex: + with amp.scale_loss(loss, self.optimizer) as scaled_loss: + scaled_loss.backward() + else: + # GRU overwrite + #self.accelerator.backward(loss, **kwargs) + + torch.cuda.empty_cache() + + if self.steps % self.gradient_accumulation_steps == 0: + + # Flatten and move accumulated gradients to CPU before clearing + self.flatten_and_store_grads() + self.gradient_accum = {} + self.memory_grad = {} + + self.flattened_memory = self.gamma_gru * self.flattened_memory_accumulation + (1 - self.gamma_gru) * self.flattened_memory_old + self.flattened_memory_old = self.flattened_memory + self.dotp_retain = self.compute_total_gradient_dot_product(self.flattened_gradient, self.structure_map, + self.flattened_memory, self.structure_map) + self.pipeline() + + self.restore_gradients_from_flat(model) + self.flattened_memory_accumulation = 0 + torch.cuda.empty_cache() + + return loss.detach() / self.args.gradient_accumulation_steps + + def compute_loss(self, model, inputs, return_outputs=False): + + + forget_inputs = inputs["forget"] + forget_inputs = { + "input_ids": forget_inputs["input_ids"], + "attention_mask": forget_inputs["attention_mask"], + "labels": forget_inputs["labels"], + } + + forget_outputs = model(**forget_inputs) + forget_loss = -forget_outputs.loss + del forget_outputs + self.store_grads(model, loss=forget_loss, typ = "objective") + + retain_inputs = inputs["retain"] + retain_inputs = { + "input_ids": retain_inputs["input_ids"], + "attention_mask": retain_inputs["attention_mask"], + "labels": retain_inputs["labels"], + } + retain_loss = self.compute_retain_loss(model=model, retain_inputs=retain_inputs) + self.store_grads(model, loss=retain_loss, typ = "retain") + + loss = forget_loss + self.steps +=1 + + + metrics = { + 'Loss': loss, + 'retain_loss': retain_loss, + 'forget_loss': forget_loss + } + print_metrics(self.steps, metrics) + + return (loss, forget_outputs) if return_outputs else loss \ No newline at end of file From d2bb6cd6f116563ad58c1819b5b3f8cabe0ba583 Mon Sep 17 00:00:00 2001 From: yuuee-www <969761826@qq.com> Date: Wed, 26 Mar 2025 12:39:27 +0800 Subject: [PATCH 2/4] Cleaned up code: removed unnecessary print statements --- src/trainer/unlearn/gru.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/trainer/unlearn/gru.py b/src/trainer/unlearn/gru.py index e7fddd5..3eba624 100644 --- a/src/trainer/unlearn/gru.py +++ b/src/trainer/unlearn/gru.py @@ -84,7 +84,7 @@ def orthogonal_component(self, g, g1): g1g1 = self.compute_total_gradient_dot_product(g1, self.structure_map, g1, self.structure_map) gg1 = self.dotp_retain - print(gg1/g1g1) + #print(gg1/g1g1) projection = gg1/g1g1* g1 orthogonal = g - projection @@ -235,7 +235,7 @@ def restore_gradients_from_flat(self, model): def pipeline(self): if self.dotp_retain<0: - print("dotp_retain:",self.dotp_retain) + #print("dotp_retain:",self.dotp_retain) self.flattened_gradient = self.orthogonal_component(self.flattened_gradient, self.flattened_memory) torch.cuda.empty_cache() @@ -306,7 +306,7 @@ def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, with amp.scale_loss(loss, self.optimizer) as scaled_loss: scaled_loss.backward() else: - # GRU overwrite + # Overwriting with GRU #self.accelerator.backward(loss, **kwargs) torch.cuda.empty_cache() @@ -332,7 +332,6 @@ def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, def compute_loss(self, model, inputs, return_outputs=False): - forget_inputs = inputs["forget"] forget_inputs = { "input_ids": forget_inputs["input_ids"], @@ -363,6 +362,6 @@ def compute_loss(self, model, inputs, return_outputs=False): 'retain_loss': retain_loss, 'forget_loss': forget_loss } - print_metrics(self.steps, metrics) + #print_metrics(self.steps, metrics) return (loss, forget_outputs) if return_outputs else loss \ No newline at end of file From 567b503756106f617bd9ff052b45b865370cb3ee Mon Sep 17 00:00:00 2001 From: yuuee-www <969761826@qq.com> Date: Wed, 9 Apr 2025 13:17:04 +0800 Subject: [PATCH 3/4] Refactor GRU trainer for enhanced readability and add 'forget_loss_type' configuration with NPO w/ GRU method integration. --- community/methods/GRU/README.md | 49 ++++++ community/methods/GRU/run.sh | 45 +++++ configs/trainer/GRU.yaml | 2 +- src/trainer/unlearn/gru.py | 300 ++++++++++++-------------------- 4 files changed, 204 insertions(+), 192 deletions(-) create mode 100644 community/methods/GRU/README.md create mode 100644 community/methods/GRU/run.sh diff --git a/community/methods/GRU/README.md b/community/methods/GRU/README.md new file mode 100644 index 0000000..fcecf9e --- /dev/null +++ b/community/methods/GRU/README.md @@ -0,0 +1,49 @@ +# TITLE + +- **Paper Title**: GRU: Mitigating the Trade-off Between Unlearning and Retention for Large Language Models +- **Authors**: Yue Wang, Qizhou Wang, Feng Liu, Wei Huang, Yali Du, Xiaojiang Du, Bo Han +- **Links**: [arXiv:2503.09117](https://arxiv.org/abs/2503.09117) + + +Provide a concise summary of your method details and its contributions. Please avoid using images to keep the repository size manageable. + +# Setup + +Please include the experimental setup such as + +- [ ] **Hyperparameters & Search Space:** Specify key hyperparameters, their search ranges, number of trials etc. +- [ ] **Computational Setup:** Mention the type and number of GPUs used. +- [ ] **DeepSpeed Configuration:** If any modifications were made to the default DeepSpeed config, specify them here. (You may include the config as a code block.) +- [ ] **Other Details:** Any additional setup details crucial for reproducing your method. + + +## Computational Setup + + +- **GPU Details**: NVIDIA A100 80GB +- **GPU Count**: The code for our method currently supports single GPU execution. We plan to enhance the codebase in the future to support multi-GPU configurations. + + +# Results + +To replicate your results, provide a `run.sh` script that contains all necessary commands to reproduce the final results. Ensure the script is well-documented. + +It would be appreciated if you can upload the final unlearned model(s) along with their `evals` folders to HuggingFace and provide the link(s) here. As the evaluations are updated, this would help us re-evaluate your model(s). + +# Citation + + +If you use this work, please cite: + +```bibtex + +@misc{wang2025grumitigatingtradeoffunlearning, + title={GRU: Mitigating the Trade-off between Unlearning and Retention for Large Language Models}, + author={Yue Wang and Qizhou Wang and Feng Liu and Wei Huang and Yali Du and Xiaojiang Du and Bo Han}, + year={2025}, + eprint={2503.09117}, + archivePrefix={arXiv}, + primaryClass={cs.LG}, + url={https://arxiv.org/abs/2503.09117}, +} +``` \ No newline at end of file diff --git a/community/methods/GRU/run.sh b/community/methods/GRU/run.sh new file mode 100644 index 0000000..3d3d197 --- /dev/null +++ b/community/methods/GRU/run.sh @@ -0,0 +1,45 @@ +#!/bin/bash + +# GRU with GradAscent +CUDA_VISIBLE_DEVICES=0 python src/train.py \ + --config-name=unlearn.yaml \ + experiment=unlearn/tofu/default \ + forget_split=forget10 \ + retain_split=retain90 \ + trainer=GRU \ + task_name=gru_ga_forget10 \ + trainer.method_args.forget_loss_type=GradAscent \ + trainer.args.gradient_accumulation_steps=16 \ + trainer.args.per_device_train_batch_size=4 + +# Evaluation for GRU with GradAscent +CUDA_VISIBLE_DEVICES=0 python src/eval.py \ + experiment=eval/tofu/default.yaml \ + forget_split=forget10 \ + model=Llama-3.2-1B-Instruct \ + task_name=gru_ga_forget10 \ + model.model_args.pretrained_model_name_or_path=saves/unlearn/gru_ga_forget10 \ + paths.output_dir=saves/unlearn/gru_ga_forget10/evals \ + retain_logs_path=saves/eval/tofu_Llama-3.2-1B-Instruct_retain90/TOFU_EVAL.json + +# GRU with NPO +CUDA_VISIBLE_DEVICES=0 python src/train.py \ + --config-name=unlearn.yaml \ + experiment=unlearn/tofu/default \ + forget_split=forget10 \ + retain_split=retain90 \ + trainer=GRU \ + task_name=gru_npo_forget10 \ + trainer.method_args.forget_loss_type=NPO \ + trainer.args.gradient_accumulation_steps=16 \ + trainer.args.per_device_train_batch_size=4 + +# Evaluation for GRU with NPO +CUDA_VISIBLE_DEVICES=0 python src/eval.py \ + experiment=eval/tofu/default.yaml \ + forget_split=forget10 \ + model=Llama-3.2-1B-Instruct \ + task_name=gru_npo_forget10 \ + model.model_args.pretrained_model_name_or_path=saves/unlearn/gru_npo_forget10 \ + paths.output_dir=saves/unlearn/gru_npo_forget10/evals \ + retain_logs_path=saves/eval/tofu_Llama-3.2-1B-Instruct_retain90/TOFU_EVAL.json diff --git a/configs/trainer/GRU.yaml b/configs/trainer/GRU.yaml index 326eb87..decaa5c 100644 --- a/configs/trainer/GRU.yaml +++ b/configs/trainer/GRU.yaml @@ -4,4 +4,4 @@ defaults: handler: GRU method_args: gamma_gru: 0.8 - \ No newline at end of file + forget_loss_type: GradAscent \ No newline at end of file diff --git a/src/trainer/unlearn/gru.py b/src/trainer/unlearn/gru.py index 3eba624..ab1ed56 100644 --- a/src/trainer/unlearn/gru.py +++ b/src/trainer/unlearn/gru.py @@ -1,90 +1,39 @@ -from transformers.utils import is_sagemaker_mp_enabled -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union -from transformers.training_args import OptimizerNames, ParallelMode, TrainingArguments -from collections.abc import Mapping -from pathlib import Path -import logging - -logger = logging.getLogger(__name__) - +from typing import Any, Dict, Union import torch from torch import nn -from copy import deepcopy -from packaging import version -from trainer.base import FinetuneTrainer - -from transformers.trainer_pt_utils import ( - nested_detach, -) - - -from transformers.utils import ( - is_sagemaker_mp_enabled, -) - -from accelerate.utils import ( - is_deepspeed_available, -) - -if is_sagemaker_mp_enabled(): - from smdistributed.modelparallel import __version__ as SMP_VERSION - - IS_SAGEMAKER_MP_POST_1_10 = version.parse(SMP_VERSION) >= version.parse("1.10") - - from transformers.trainer_pt_utils import ( - smp_forward_only, - smp_nested_concat, - ) -else: - IS_SAGEMAKER_MP_POST_1_10 = False - -if is_deepspeed_available(): - import deepspeed - - - - from trainer.unlearn.base import UnlearnTrainer import numpy as np +from trainer.utils import compute_dpo_loss +from trainer.unlearn.grad_diff import GradDiff -def print_metrics(step, metrics): - """ - Print current training metrics in a formatted way. - - Args: - epoch (int): Current epoch or iteration number. - metrics (dict): A dictionary containing metric names as keys and their current values. - """ - # Prepare the formatted string - metrics_string = ', '.join([f"{key}: {value:.4f}" for key, value in metrics.items()]) - print(f"Step {step}: {metrics_string}") -class GRU(UnlearnTrainer): - def __init__(self, gamma_gru=0.8, *args, **kwargs): +class GRU(GradDiff,UnlearnTrainer): + def __init__(self, gamma_gru=0.8, forget_loss_type="GradAscent", *args, **kwargs): super().__init__(*args, **kwargs) self.gamma_gru = gamma_gru + self.forget_loss_type = forget_loss_type self.gradient_accumulation_steps = kwargs["args"].gradient_accumulation_steps + if self.ref_model is None and self.forget_loss_type == "NPO": + self.ref_model = self._prepare_ref_model(self.model) + #self.ref_model = self.model.to(self.args.device) - self.dotp_retain = None + # Initialization of internal variables to store gradients and computational states + self.dotp_retain = 0.0 self.flattened_gradient = 0.0 - self.flattened_memory = 0.0 - self.flattened_memory_old = 0.0 - self.flattened_memory_accumulation = 0.0 + self.flattened_retain = 0.0 + self.flattened_retain_prev = 0.0 + self.flattened_retain_accumulation = 0.0 self.structure_map = None - self.steps = 0 - - self.gradient_accum = {} - - self.memory_grad = {} + self.gradient_accum = {} + self.retain_grad = {} def orthogonal_component(self, g, g1): - - g1g1 = self.compute_total_gradient_dot_product(g1, self.structure_map, g1, self.structure_map) + """Compute the component of g orthogonal to g1.""" + g1g1 = self.compute_total_gradient_dot_product(g1, g1, self.structure_map) gg1 = self.dotp_retain - #print(gg1/g1g1) projection = gg1/g1g1* g1 orthogonal = g - projection @@ -92,21 +41,23 @@ def orthogonal_component(self, g, g1): def store_grads(self, model, loss=None, typ=None): """ - Accumulates gradients of specified layers, preserving their original shapes. - Optionally, adjusts which layers are trainable just before computing gradients. + Captures and stores gradients instead of applying them directly within the training loop. This method + allows for sophisticated gradient manipulations before they are used to update the model, substituting + the portion of `training_step` where gradients would typically be computed and immediately applied. Args: model (torch.nn.Module): The model from which to store gradients. loss (torch.Tensor, optional): The loss tensor to perform backward operation. If provided, will compute gradients. - - Returns: - None: Modifies internal tensors to store accumulated gradients. """ # Perform backward pass if a loss tensor is provided if loss: + + # if self.args.n_gpu > 1: + # loss = loss.mean() + loss = loss / self.gradient_accumulation_steps - loss.backward() + loss.backward() # Compute gradients # Loop through parameters and accumulate gradients for name, param in model.named_parameters(): @@ -118,7 +69,7 @@ def store_grads(self, model, loss=None, typ=None): if typ == "objective": target_dict = self.gradient_accum elif typ == "retain": - target_dict = self.memory_grad + target_dict = self.retain_grad else: raise ValueError("Invalid type specified for gradient storage") @@ -130,9 +81,9 @@ def store_grads(self, model, loss=None, typ=None): target_dict[name] += param.grad.detach() if loss: - model.zero_grad() + model.zero_grad() # Clear gradients after storage - def flatten_and_store_grads(self): + def flatten2cpu(self): """ Flattens accumulated gradients from different gradient dictionaries, moves them to CPU, and stores them along with a structure map for each type of gradient. @@ -154,19 +105,17 @@ def flatten_to_cpu_and_record_structure(gradient_dict): return torch.tensor([], dtype=torch.float32).to('cpu'), [] - self.flattened_gradient, self.structure_map = flatten_to_cpu_and_record_structure(self.gradient_accum) - - self.flattened_memory_accumulation, _ = flatten_to_cpu_and_record_structure(self.memory_grad) + self.flattened_gradient, self.structure_map = flatten_to_cpu_and_record_structure(self.gradient_accum) + self.flattened_retain_accumulation, _ = flatten_to_cpu_and_record_structure(self.retain_grad) - def compute_total_gradient_dot_product(self, flattened_grads1, structure_map1, flattened_grads2, structure_map2): + def compute_total_gradient_dot_product(self, flattened_grads1, flattened_grads2, structure_map): """ Computes the total dot product between gradients from two sets of flattened gradients and their respective structure maps. Args: flattened_grads1 (torch.Tensor): The first flattened gradient tensor. - structure_map1 (list): A list of tuples containing parameter names and their corresponding shapes for the first set of gradients. flattened_grads2 (torch.Tensor): The second flattened gradient tensor. - structure_map2 (list): A list of tuples containing parameter names and their corresponding shapes for the second set of gradients. + structure_map (list): A list of tuples containing parameter names and their corresponding shapes for the second set of gradients. Returns: float: The total dot product summed across all matching layers. @@ -183,7 +132,7 @@ def compute_total_gradient_dot_product(self, flattened_grads1, structure_map1, f # for ((name1, shape1), (name2, shape2)) in zip(structure_map1, structure_map2): # assert name1 == name2 and shape1 == shape2, f"Gradient mismatch: {name1} vs {name2} or {shape1} vs {shape2}" - for ((name1, shape1), (name2, shape2)) in zip(structure_map1, structure_map2): + for ((name1, shape1), (name2, shape2)) in zip(structure_map, structure_map): assert shape1 == shape2, f"Gradient mismatch: {name1} vs {name2} or {shape1} vs {shape2}" size = np.prod(shape1) # Total number of elements in this layer's gradient @@ -198,16 +147,11 @@ def compute_total_gradient_dot_product(self, flattened_grads1, structure_map1, f return total_dot_product - def restore_gradients_from_flat(self, model): + def restore_gradients(self, model): """ - Restores gradients to the model's parameters directly from a flattened gradient tensor. + Restores gradients to the model's parameters directly from self.flattened_gradient. - Args: - model (torch.nn.Module): The model to which the gradients will be restored. - flattened_grads (torch.Tensor): The flattened gradient tensor. - structure_map (list): A list of tuples containing parameter names and their corresponding shapes. """ - index = 0 # Index to track position in the flattened gradient tensor for name, shape in self.structure_map: @@ -234,9 +178,9 @@ def restore_gradients_from_flat(self, model): raise ValueError("Total number of gradient elements does not match the length of the flattened gradient tensor.") def pipeline(self): - if self.dotp_retain<0: + if self.dotp_retain < 0: #print("dotp_retain:",self.dotp_retain) - self.flattened_gradient = self.orthogonal_component(self.flattened_gradient, self.flattened_memory) + self.flattened_gradient = self.orthogonal_component(self.flattened_gradient, self.flattened_retain) torch.cuda.empty_cache() def compute_retain_loss(self, model, retain_inputs): @@ -246,122 +190,96 @@ def compute_retain_loss(self, model, retain_inputs): return retain_loss def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor: - """ - Perform a training step on a batch of inputs. - - Subclass and override to inject custom behavior. - - Args: - model (`nn.Module`): - The model to train. - inputs (`Dict[str, Union[torch.Tensor, Any]]`): - The inputs and targets of the model. - - The dictionary will be unpacked before being fed to the model. Most models expect the targets under the - argument `labels`. Check your model's documentation for all accepted arguments. - - Return: - `torch.Tensor`: The tensor with training loss on this batch. + """Overridden training_step to include custom GRU logic. + + Notes: + - Gradient computation via backward pass has already been performed by `store_grads`. + - This method performs additional operations on the stored gradients, including flattening gradients, smoothing retain gradients via EMA, and adjusting + gradients by projection. + - After these custom manipulations, modified gradients are restored back to model parameters before optimization. + """ model.train() if hasattr(self.optimizer, "train") and callable(self.optimizer.train): self.optimizer.train() inputs = self._prepare_inputs(inputs) - if is_sagemaker_mp_enabled(): - loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps) - return loss_mb.reduce_mean().detach().to(self.args.device) - with self.compute_loss_context_manager(): loss = self.compute_loss(model, inputs) - del inputs - if ( - self.args.torch_empty_cache_steps is not None - and self.state.global_step % self.args.torch_empty_cache_steps == 0 - ): - if is_torch_xpu_available(): - torch.xpu.empty_cache() - elif is_torch_mlu_available(): - torch.mlu.empty_cache() - elif is_torch_musa_available(): - torch.musa.empty_cache() - elif is_torch_npu_available(): - torch.npu.empty_cache() - elif is_torch_mps_available(min_version="2.0"): - torch.mps.empty_cache() - else: - torch.cuda.empty_cache() - - kwargs = {} + torch.cuda.empty_cache() - # For LOMO optimizers you need to explicitly use the learnign rate - if self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]: - kwargs["learning_rate"] = self._get_learning_rate() + if self.steps % self.gradient_accumulation_steps == 0: - if self.args.n_gpu > 1: - loss = loss.mean() # mean() to average on multi-gpu parallel training + # Flatten and move accumulated gradients to CPU before clearing + self.flatten2cpu() + self.gradient_accum = {} + self.retain_grad = {} - if self.use_apex: - with amp.scale_loss(loss, self.optimizer) as scaled_loss: - scaled_loss.backward() - else: - # Overwriting with GRU - #self.accelerator.backward(loss, **kwargs) + # For Stable Estimation + self.flattened_retain = self.gamma_gru * self.flattened_retain_accumulation + (1 - self.gamma_gru) * self.flattened_retain_prev + self.flattened_retain_prev = self.flattened_retain + + self.dotp_retain = self.compute_total_gradient_dot_product(self.flattened_gradient, self.flattened_retain, self.structure_map) + self.pipeline() + self.restore_gradients(model) + self.flattened_retain_accumulation = 0 torch.cuda.empty_cache() - if self.steps % self.gradient_accumulation_steps == 0: - - # Flatten and move accumulated gradients to CPU before clearing - self.flatten_and_store_grads() - self.gradient_accum = {} - self.memory_grad = {} - - self.flattened_memory = self.gamma_gru * self.flattened_memory_accumulation + (1 - self.gamma_gru) * self.flattened_memory_old - self.flattened_memory_old = self.flattened_memory - self.dotp_retain = self.compute_total_gradient_dot_product(self.flattened_gradient, self.structure_map, - self.flattened_memory, self.structure_map) - self.pipeline() - - self.restore_gradients_from_flat(model) - self.flattened_memory_accumulation = 0 - torch.cuda.empty_cache() - return loss.detach() / self.args.gradient_accumulation_steps def compute_loss(self, model, inputs, return_outputs=False): - forget_inputs = inputs["forget"] - forget_inputs = { - "input_ids": forget_inputs["input_ids"], - "attention_mask": forget_inputs["attention_mask"], - "labels": forget_inputs["labels"], - } - - forget_outputs = model(**forget_inputs) - forget_loss = -forget_outputs.loss - del forget_outputs - self.store_grads(model, loss=forget_loss, typ = "objective") - - retain_inputs = inputs["retain"] - retain_inputs = { - "input_ids": retain_inputs["input_ids"], - "attention_mask": retain_inputs["attention_mask"], - "labels": retain_inputs["labels"], - } - retain_loss = self.compute_retain_loss(model=model, retain_inputs=retain_inputs) - self.store_grads(model, loss=retain_loss, typ = "retain") - - loss = forget_loss - self.steps +=1 - - - metrics = { - 'Loss': loss, - 'retain_loss': retain_loss, - 'forget_loss': forget_loss + if self.forget_loss_type == "GradAscent": + + forget_inputs = inputs["forget"] + forget_inputs = { + "input_ids": forget_inputs["input_ids"], + "attention_mask": forget_inputs["attention_mask"], + "labels": forget_inputs["labels"], } - #print_metrics(self.steps, metrics) + + forget_outputs = model(**forget_inputs) + forget_loss = -forget_outputs.loss + del forget_outputs + self.store_grads(model, loss=forget_loss, typ = "objective") + + retain_inputs = inputs["retain"] + retain_inputs = { + "input_ids": retain_inputs["input_ids"], + "attention_mask": retain_inputs["attention_mask"], + "labels": retain_inputs["labels"], + } + retain_loss = self.compute_retain_loss(model=model, retain_inputs=retain_inputs) + self.store_grads(model, loss=retain_loss, typ = "retain") + + loss = forget_loss + self.steps +=1 + + elif self.forget_loss_type == "NPO": + + forget_inputs = inputs["forget"] + forget_loss, forget_outputs = compute_dpo_loss( + model=model, + ref_model=self.ref_model, + win_inputs=None, + lose_inputs=forget_inputs, + beta=0.1, + ) + del forget_outputs + self.store_grads(model, loss=forget_loss, typ = "objective") + + retain_inputs = inputs["retain"] + retain_inputs = { + "input_ids": retain_inputs["input_ids"], + "attention_mask": retain_inputs["attention_mask"], + "labels": retain_inputs["labels"], + } + retain_loss = self.compute_retain_loss(model=model, retain_inputs=retain_inputs) + self.store_grads(model, loss=retain_loss, typ = "retain") + + loss = forget_loss + retain_loss + return (loss, forget_outputs) if return_outputs else loss \ No newline at end of file From 911e52a1cdd75e753e868b568257e5d2796af373 Mon Sep 17 00:00:00 2001 From: yuuee-www <969761826@qq.com> Date: Wed, 14 May 2025 17:34:16 +0800 Subject: [PATCH 4/4] Clean up formatting and update GRU README (results still running) --- community/methods/GRU/README.md | 24 ++++++++---------------- src/trainer/unlearn/gru.py | 5 ----- 2 files changed, 8 insertions(+), 21 deletions(-) diff --git a/community/methods/GRU/README.md b/community/methods/GRU/README.md index fcecf9e..0c23eaf 100644 --- a/community/methods/GRU/README.md +++ b/community/methods/GRU/README.md @@ -1,34 +1,26 @@ -# TITLE +# GRU - **Paper Title**: GRU: Mitigating the Trade-off Between Unlearning and Retention for Large Language Models - **Authors**: Yue Wang, Qizhou Wang, Feng Liu, Wei Huang, Yali Du, Xiaojiang Du, Bo Han - **Links**: [arXiv:2503.09117](https://arxiv.org/abs/2503.09117) -Provide a concise summary of your method details and its contributions. Please avoid using images to keep the repository size manageable. +This work proposes **Gradient Rectified Unlearning (GRU)**, a general framework for improving unlearning performance without sacrificing retention in large language models. GRU modifies the gradient update rule to remove the component of the unlearning gradient that conflicts with the retention gradient. # Setup -Please include the experimental setup such as +- **Hyperparameters & Search Space**: + - Gradient EMA smoothing factor \(\gamma \in \{0.8, 0.9, 0.95, \text{N/A}\}\) -- [ ] **Hyperparameters & Search Space:** Specify key hyperparameters, their search ranges, number of trials etc. -- [ ] **Computational Setup:** Mention the type and number of GPUs used. -- [ ] **DeepSpeed Configuration:** If any modifications were made to the default DeepSpeed config, specify them here. (You may include the config as a code block.) -- [ ] **Other Details:** Any additional setup details crucial for reproducing your method. +- **GPU Type**: NVIDIA A100 80GB +- **GPU Usage**: Current code supports **single GPU execution only**. Multi-GPU support is not yet implemented. -## Computational Setup - - -- **GPU Details**: NVIDIA A100 80GB -- **GPU Count**: The code for our method currently supports single GPU execution. We plan to enhance the codebase in the future to support multi-GPU configurations. - +- **DeepSpeed Configuration**: + GRU currently **does not support DeepSpeed** due to its reliance on fine-grained gradient manipulation. Please ensure DeepSpeed is disabled for all GRU experiments. # Results -To replicate your results, provide a `run.sh` script that contains all necessary commands to reproduce the final results. Ensure the script is well-documented. - -It would be appreciated if you can upload the final unlearned model(s) along with their `evals` folders to HuggingFace and provide the link(s) here. As the evaluations are updated, this would help us re-evaluate your model(s). # Citation diff --git a/src/trainer/unlearn/gru.py b/src/trainer/unlearn/gru.py index ab1ed56..9c1f237 100644 --- a/src/trainer/unlearn/gru.py +++ b/src/trainer/unlearn/gru.py @@ -17,7 +17,6 @@ def __init__(self, gamma_gru=0.8, forget_loss_type="GradAscent", *args, **kwarg self.gradient_accumulation_steps = kwargs["args"].gradient_accumulation_steps if self.ref_model is None and self.forget_loss_type == "NPO": self.ref_model = self._prepare_ref_model(self.model) - #self.ref_model = self.model.to(self.args.device) # Initialization of internal variables to store gradients and computational states self.dotp_retain = 0.0 @@ -129,9 +128,6 @@ def compute_total_gradient_dot_product(self, flattened_grads1, flattened_grads2, flattened_grads1 = flattened_grads1.to('cuda') flattened_grads2 = flattened_grads2.to('cuda') - # for ((name1, shape1), (name2, shape2)) in zip(structure_map1, structure_map2): - # assert name1 == name2 and shape1 == shape2, f"Gradient mismatch: {name1} vs {name2} or {shape1} vs {shape2}" - for ((name1, shape1), (name2, shape2)) in zip(structure_map, structure_map): assert shape1 == shape2, f"Gradient mismatch: {name1} vs {name2} or {shape1} vs {shape2}" @@ -179,7 +175,6 @@ def restore_gradients(self, model): def pipeline(self): if self.dotp_retain < 0: - #print("dotp_retain:",self.dotp_retain) self.flattened_gradient = self.orthogonal_component(self.flattened_gradient, self.flattened_retain) torch.cuda.empty_cache()