Skip to content

[JAX] Updated: unbalanced CP with THD format #1709

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from

Conversation

huanghua1994
Copy link
Collaborator

Description

This PR rebased PR 1565 to the main branch and fixed the bug in the original PR. After this PR, the non-balancing implementation of CP with THD format will be supported.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Support unbalanced context parallelism with THD format
  • Support non-causal mask
  • Add related unit tests

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@huanghua1994 huanghua1994 force-pushed the rewang/thd-ring-unbalanced-updated branch from 06a9470 to e24818c Compare April 22, 2025 17:48
@huanghua1994
Copy link
Collaborator Author

/te-ci jax L1

@huanghua1994 huanghua1994 force-pushed the rewang/thd-ring-unbalanced-updated branch 2 times, most recently from dd846af to 3dd1874 Compare April 25, 2025 16:37
@huanghua1994
Copy link
Collaborator Author

/te-ci jax L1

@huanghua1994 huanghua1994 force-pushed the rewang/thd-ring-unbalanced-updated branch 3 times, most recently from cd87949 to f3d92a5 Compare May 1, 2025 15:48
@KshitijLakhani KshitijLakhani self-requested a review May 2, 2025 21:45
zlsh80826 and others added 4 commits May 10, 2025 08:58
Signed-off-by: Reese Wang <rewang@nvidia.com>

fix up

Signed-off-by: Reese Wang <rewang@nvidia.com>

Add padding mask

Signed-off-by: Reese Wang <rewang@nvidia.com>

Rebase and fix

Signed-off-by: Hua Huang <huah@nvidia.com>
Signed-off-by: Hua Huang <huah@nvidia.com>
Signed-off-by: Hua Huang <huah@nvidia.com>
@huanghua1994 huanghua1994 force-pushed the rewang/thd-ring-unbalanced-updated branch from f3d92a5 to bdadabe Compare May 10, 2025 17:11
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants