Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add JSD Loss for Distillation #425

Open
wants to merge 10 commits into
base: main
Choose a base branch
from

Conversation

austin362667
Copy link
Collaborator

@austin362667 austin362667 commented Dec 4, 2024

Summary

Caution

This PR depends on #417. Do not merge until #417 (later #432) is merged.

This is a pure torch compiled, chunked fused linear JSD Loss, aiming for knowledge distillation.

Jensen-Shannon Divergence Loss

This PR implements Jensen-Shannon Divergence (JSD) loss as the soft learning objective in a distillation setting (teacher & student). This component can be replaced with other losses (e.g., KL divergence) as distillation_loss_fn.

JSD is defined as the average of the KL divergences between each distribution and the mean distribution:

$$\text{JSD}(P || Q) = \frac{1}{2} \text{KL}(P || M) + \frac{1}{2} \text{KL}(Q || M), \quad \text{where } M = \frac{1}{2}(P + Q)$$

Here, Pand Q are the two probability distributions, and M is their average.

Testing Done

Below figures are benchmark results with different chunk_size, which also significantly affects performance.

Hint:

User can tune their chunk_size as suggested by the liger paper for the moment:

$$2^{\lceil \log_2 \lceil \frac{BT}{V/H} \rceil \rceil}$$

Memory

  1. chunk_size = 1

    distill_jsd_loss_memory_chunk_size_1

  2. chunk_size = 1024

    distill_jsd_loss_memory_chunk_size_1024

Speed (Elapsed Time)

  1. chunk_size = 1

    distill_jsd_loss_speed_chunk_size_1

  2. chunk_size = 1024

    distill_jsd_loss_speed_chunk_size_1024

  • Hardware Type: NVIDIA H100 80GB HBM3 (SXM5)
  • run make test to ensure correctness
  • run make checkstyle to ensure code style
  • run make test-convergence to ensure convergence

ByronHsu pushed a commit that referenced this pull request Dec 9, 2024
## Summary

Made #417 from the main
repo.

Thanks to the nice suggestions from @Tcc0403 and @hongpeng-guo. This PR
is the s first split from
#408, focusing solely on
introducing the Knowledge Distillation base class. As a result, this PR
does not include any tests at the moment.

#### Code Changes

1. Refactor `beta` into two weights: `weight_hard_loss` and
`weight_soft_loss`, as coefficients between `hard_loss` and `soft_loss`.
@Tcc0403 also pointed out that we could use `torch.lerp` if applicable.

2. Pass `teacher_logits` and `student_logits` directly to the divergence
loss function. This avoids redundant computations of converting logits
to log probabilities and then reverting them to raw logits. However note
that we are not reusing the `student_log_probs` value calculated during
`ce_loss` in distillation base.

    1. Remove the unnecessary `get_batch_logps` in `test/utils.py`.

3. Modify `chunking` dimensions from `B` to `B * T`. Thanks to
@hongpeng-guo's great advice.
1. Fix the loss calculation to use per-token values instead of averaging
across the sequence length dimension.

4. Normalize the `distillation_loss` using `(full_target !=
ignore_index).sum()`.

#### TODO  

1. [X] Although a slightly slowdown is reasonable, we need to
investigate why this PR's implementation is **significantly slower**
compared to the naive approach. Thanks to @Tcc0403 's clarification.
    
The issue arises because we are not properly configuring the
`chunk_size` for the `B * T` dimension, which is extremely large (a few
thousand). The previous default of 1 results in an excessive number of
chunks.

In contrast, this problem does not occur with the preference loss, as
chunking is performed on the `B` dimension. This produces fewer than 10
chunks, which is efficient and works as expected.

In conclusion, I set `chunk_size` to `1024` works pretty well in new
benchmark results as shown in
#425

2. [ ]
#417 (comment)

#### Knowledge Distillation

Knowledge Distillation (KD; [Hinton et al.
2015](https://arxiv.org/abs/1503.02531), [Gou et al.
2020](https://arxiv.org/abs/2006.05525)) is a straightforward way to
build a smaller, cheaper model (“student model”) to speed up inference
by transferring skills from a pre-trained expensive model (“teacher
model”) into the student.

In knowledge distillation, a student model is trained to replicate the
outputs of a teacher model using a distillation loss. Neural networks
typically include a softmax layer; for instance, a large language model
produces a probability distribution over tokens. Let `z_t` and `z_s`
represent the logits before the softmax layer for the teacher and
student models, respectively. The distillation loss reduces the
discrepancy between the two softmax outputs at a high temperature `T`.
When ground truth labels `y` are available, this approach can be
combined with a supervised learning objective, such as cross-entropy, to
compare the student’s outputs with the ground truth.

The combined loss function is defined as:

```math
\mathcal{L}_{\text{knowledge distillation}} = \mathcal{w}_{\text{soft}} \cdot \mathcal{L}_{\text{distill}}(\mathbf{z_t}, \mathbf{z_s}, T) + \mathcal{w}_{\text{hard}} \cdot \mathcal{L}_{\text{cross entropy}}(\mathbf{y}, \mathbf{z_s}),
``` 

Here,  we directly pass in `logits` rather than `logpbs`. @Tcc0403 

#### Shared `DistillationBase`

To support various distillation learning objectives, this PR aims to add
a `LigerFusedLinearDistillationBase` which is basically same as propose
by @hongpeng-guo within this discussion
#371 (comment).
Thank you @hongpeng-guo for thinking through this.

## Testing Done

I'll post JSD tests and benchmarks results in next PR:
#425

- Hardware Type: L40S
- [ ] run `make test` to ensure correctness
- [ ] run `make checkstyle` to ensure code style
- [ ] run `make test-convergence` to ensure convergence

---------

Signed-off-by: Austin Liu <austin362667@gmail.com>
Co-authored-by: shivam15s <shivam15800@gmail.com>
@shivam15s shivam15s force-pushed the austin362667/chunked_compiled_jsd_loss branch from d052bab to 1d3b064 Compare December 17, 2024 02:18
Signed-off-by: Austin Liu <austin362667@gmail.com>
Signed-off-by: Austin Liu <austin362667@gmail.com>
Signed-off-by: Austin Liu <austin362667@gmail.com>
Signed-off-by: Austin Liu <austin362667@gmail.com>
…eduction

Signed-off-by: Austin Liu <austin362667@gmail.com>
@austin362667 austin362667 force-pushed the austin362667/chunked_compiled_jsd_loss branch from 1d3b064 to 87a7f8c Compare January 7, 2025 10:02
Signed-off-by: Austin Liu <austin362667@gmail.com>
@austin362667 austin362667 marked this pull request as ready for review January 7, 2025 10:24
@austin362667
Copy link
Collaborator Author

cc @Mecoli1219

student_logits = student_logits / temperature
teacher_logits = teacher_logits / temperature
# Convert to probabilities
student_probs = F.softmax(student_logits, dim=-1)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

https://pytorch.org/docs/stable/generated/torch.nn.functional.log_softmax.html

It's recommended to do logsoftmax for speed and numerical stability instead of two separate operations according to documentation.

Comment on lines 19 to 22
student_probs = F.softmax(student_logits, dim=-1)
teacher_probs = F.softmax(teacher_logits, dim=-1)

log_mean_probs = torch.log((student_probs + teacher_probs) / 2)
Copy link
Collaborator

@Tcc0403 Tcc0403 Jan 7, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps we can add beta support for general jsd as #278 did.

Signed-off-by: Austin Liu <austin362667@gmail.com>
Signed-off-by: Austin Liu <austin362667@gmail.com>
@austin362667
Copy link
Collaborator Author

We can refactor distillation base so that it aligns with the code structure of preference base in follow-up PRs.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants