-
Notifications
You must be signed in to change notification settings - Fork 240
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add JSD Loss for Distillation #425
base: main
Are you sure you want to change the base?
Conversation
## Summary Made #417 from the main repo. Thanks to the nice suggestions from @Tcc0403 and @hongpeng-guo. This PR is the s first split from #408, focusing solely on introducing the Knowledge Distillation base class. As a result, this PR does not include any tests at the moment. #### Code Changes 1. Refactor `beta` into two weights: `weight_hard_loss` and `weight_soft_loss`, as coefficients between `hard_loss` and `soft_loss`. @Tcc0403 also pointed out that we could use `torch.lerp` if applicable. 2. Pass `teacher_logits` and `student_logits` directly to the divergence loss function. This avoids redundant computations of converting logits to log probabilities and then reverting them to raw logits. However note that we are not reusing the `student_log_probs` value calculated during `ce_loss` in distillation base. 1. Remove the unnecessary `get_batch_logps` in `test/utils.py`. 3. Modify `chunking` dimensions from `B` to `B * T`. Thanks to @hongpeng-guo's great advice. 1. Fix the loss calculation to use per-token values instead of averaging across the sequence length dimension. 4. Normalize the `distillation_loss` using `(full_target != ignore_index).sum()`. #### TODO 1. [X] Although a slightly slowdown is reasonable, we need to investigate why this PR's implementation is **significantly slower** compared to the naive approach. Thanks to @Tcc0403 's clarification. The issue arises because we are not properly configuring the `chunk_size` for the `B * T` dimension, which is extremely large (a few thousand). The previous default of 1 results in an excessive number of chunks. In contrast, this problem does not occur with the preference loss, as chunking is performed on the `B` dimension. This produces fewer than 10 chunks, which is efficient and works as expected. In conclusion, I set `chunk_size` to `1024` works pretty well in new benchmark results as shown in #425 2. [ ] #417 (comment) #### Knowledge Distillation Knowledge Distillation (KD; [Hinton et al. 2015](https://arxiv.org/abs/1503.02531), [Gou et al. 2020](https://arxiv.org/abs/2006.05525)) is a straightforward way to build a smaller, cheaper model (“student model”) to speed up inference by transferring skills from a pre-trained expensive model (“teacher model”) into the student. In knowledge distillation, a student model is trained to replicate the outputs of a teacher model using a distillation loss. Neural networks typically include a softmax layer; for instance, a large language model produces a probability distribution over tokens. Let `z_t` and `z_s` represent the logits before the softmax layer for the teacher and student models, respectively. The distillation loss reduces the discrepancy between the two softmax outputs at a high temperature `T`. When ground truth labels `y` are available, this approach can be combined with a supervised learning objective, such as cross-entropy, to compare the student’s outputs with the ground truth. The combined loss function is defined as: ```math \mathcal{L}_{\text{knowledge distillation}} = \mathcal{w}_{\text{soft}} \cdot \mathcal{L}_{\text{distill}}(\mathbf{z_t}, \mathbf{z_s}, T) + \mathcal{w}_{\text{hard}} \cdot \mathcal{L}_{\text{cross entropy}}(\mathbf{y}, \mathbf{z_s}), ``` Here, we directly pass in `logits` rather than `logpbs`. @Tcc0403 #### Shared `DistillationBase` To support various distillation learning objectives, this PR aims to add a `LigerFusedLinearDistillationBase` which is basically same as propose by @hongpeng-guo within this discussion #371 (comment). Thank you @hongpeng-guo for thinking through this. ## Testing Done I'll post JSD tests and benchmarks results in next PR: #425 - Hardware Type: L40S - [ ] run `make test` to ensure correctness - [ ] run `make checkstyle` to ensure code style - [ ] run `make test-convergence` to ensure convergence --------- Signed-off-by: Austin Liu <austin362667@gmail.com> Co-authored-by: shivam15s <shivam15800@gmail.com>
d052bab
to
1d3b064
Compare
Signed-off-by: Austin Liu <austin362667@gmail.com>
Signed-off-by: Austin Liu <austin362667@gmail.com>
Signed-off-by: Austin Liu <austin362667@gmail.com>
Signed-off-by: Austin Liu <austin362667@gmail.com>
…eduction Signed-off-by: Austin Liu <austin362667@gmail.com>
1d3b064
to
87a7f8c
Compare
Signed-off-by: Austin Liu <austin362667@gmail.com>
cc @Mecoli1219 |
student_logits = student_logits / temperature | ||
teacher_logits = teacher_logits / temperature | ||
# Convert to probabilities | ||
student_probs = F.softmax(student_logits, dim=-1) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
https://pytorch.org/docs/stable/generated/torch.nn.functional.log_softmax.html
It's recommended to do logsoftmax for speed and numerical stability instead of two separate operations according to documentation.
student_probs = F.softmax(student_logits, dim=-1) | ||
teacher_probs = F.softmax(teacher_logits, dim=-1) | ||
|
||
log_mean_probs = torch.log((student_probs + teacher_probs) / 2) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Perhaps we can add beta support for general jsd as #278 did.
Signed-off-by: Austin Liu <austin362667@gmail.com>
We can refactor distillation base so that it aligns with the code structure of preference base in follow-up PRs. |
Summary
Caution
This PR depends on #417. Do not merge until #417 (later #432) is merged.
This is a pure torch compiled, chunked fused linear JSD Loss, aiming for knowledge distillation.
Jensen-Shannon Divergence Loss
This PR implements Jensen-Shannon Divergence (JSD) loss as the soft learning objective in a distillation setting (teacher & student). This component can be replaced with other losses (e.g., KL divergence) as
distillation_loss_fn
.JSD is defined as the average of the KL divergences between each distribution and the mean distribution:
Here,
P
andQ
are the two probability distributions, andM
is their average.Testing Done
Below figures are benchmark results with different
chunk_size
, which also significantly affects performance.Hint:
User can tune their
chunk_size
as suggested by the liger paper for the moment:Memory
chunk_size
= 1chunk_size
= 1024Speed (Elapsed Time)
chunk_size
= 1chunk_size
= 1024make test
to ensure correctnessmake checkstyle
to ensure code stylemake test-convergence
to ensure convergence