Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Distributed layers #1270

Draft
wants to merge 6 commits into
base: main
Choose a base branch
from
Draft

Distributed layers #1270

wants to merge 6 commits into from

Conversation

angeloskath
Copy link
Member

Adds linear layers that allow training and inference of a model sharded across several devices. The main things added are

  • float16/bfloat16 reductions for MPI
  • AllToShardedLinear and its quantized sibling
  • ShardedToAllLinear and its quantized sibling

simply changing linear layers to the above results in a model that works out of the box with distributed inference and training.

I am starting it as a draft so that we can iterate a bit on the design. The negative aspects of the above design are that we have yet another linear layer to think about when implementing LoRA and friends or weird new quantizations for instance. Perhaps it would be better to make the above layers with an internal linear layer so model surgery that swaps linear layers would still work out of the box.

sl = cls(input_dims, output_dims, False, group)
# The multiplication with 1.0 forces a copy, perhaps change to
# something better when available.
sl.weight = linear_layer.weight[r * step : (r + 1) * step] * 1
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it possible the input buffer could be donated so we'd still hold on to the memory?

If so, maybe another option is to do sl.weight[ ... ] = ... that will force the copy since it's a slice update?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice! That does sound better actually!

@awni
Copy link
Member

awni commented Jul 17, 2024

I kind of like this design. I like that it's all quite simple and easy to follow and we have a lot of control over how to shard the model (as in ml-explore/mlx-examples#890). We could possibly find a way to reduce the code needed for adding a new custom linear-like layer.. but the simplicity is nice, I wouldn't want to give that up.

@angeloskath angeloskath force-pushed the distributed-layers branch 2 times, most recently from 061d214 to b32ce2c Compare August 29, 2024 08:20
@angeloskath angeloskath force-pushed the distributed-layers branch 2 times, most recently from ab26116 to 3d431c0 Compare September 6, 2024 18:03
@awni awni mentioned this pull request Sep 16, 2024
@angeloskath angeloskath force-pushed the distributed-layers branch 5 times, most recently from 2298954 to 1697581 Compare November 5, 2024 19:35
@awni awni force-pushed the distributed-layers branch from 1697581 to 5921570 Compare January 14, 2025 21:10
@awni awni force-pushed the distributed-layers branch from 5921570 to 31ba022 Compare January 15, 2025 14:08
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants