TransformerLensOrg / TransformerLens

A library for mechanistic interpretability of GPT-style language models
https://transformerlensorg.github.io/TransformerLens/
MIT License

Add support for model parallelism #105

Closed: neelnanda-io closed this issue 1 year ago

neelnanda-io commented 1 year ago

Add support for having a model with layers split across several GPUs.

Make sure the layers (and their HookPoints) know what device they're on, so that hooks can avoid needlessly moving information between GPUs. ActivationCache is a dictionary and should work by default.
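One way to read "layers know what device they're on" is to keep a placement map from model components to devices, so hook code can look up the device for a given hook name. That lets a hook build something like a patching tensor once, on the right GPU, rather than moving it on every call. A minimal sketch, assuming a hypothetical placement dictionary and a hypothetical helper `device_for_hook` (neither is TransformerLens API). The only real detail it relies on is that per-layer hook names start with `blocks.<i>.`. Inside a hook itself, reading `.device` off the incoming activation tensor is the simpler check.

```python
import re
from typing import Dict

# Hypothetical placement map (component name -> device) for a 2-layer toy
# model split over 2 GPUs. Not a TransformerLens data structure.
EXAMPLE_PLACEMENT: Dict[str, str] = {
    "embed": "cuda:0",
    "blocks.0": "cuda:0",
    "blocks.1": "cuda:1",
    "unembed": "cuda:1",
}

# Per-layer hook names look like "blocks.<i>.<rest>".
_BLOCK_HOOK = re.compile(r"^blocks\.(\d+)\.")


def device_for_hook(hook_name: str, placement: Dict[str, str]) -> str:
    """Return the device that the activation at `hook_name` lives on."""
    match = _BLOCK_HOOK.match(hook_name)
    if match:
        # Hooks inside a block live wherever that block was placed.
        return placement[f"blocks.{match.group(1)}"]
    if hook_name in ("hook_embed", "hook_pos_embed"):
        # Embedding hooks run before the first block.
        return placement["embed"]
    # Assumption: every other top-level hook (e.g. ln_final hooks) runs
    # after the last block, alongside the unembed.
    return placement["unembed"]
```

For example, `device_for_hook("blocks.1.attn.hook_pattern", EXAMPLE_PLACEMENT)` returns `"cuda:1"`, and `device_for_hook("hook_embed", EXAMPLE_PLACEMENT)` returns `"cuda:0"`.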

The MVP would be 2 GPUs: putting the embed and the first half of the layers on GPU 1, and the second half and the unembed on GPU 2. This is probably all that's needed to support e.g. NeoX?

I'm not sure what the most elegant way of doing this is, or how to do it without making the code really messy. I lean towards either adding a method that edits the model to move layers between devices, or making a separate ParallelHookedTransformer class.

anshradh commented 1 year ago

To be clear: you're looking for support of Pipeline Parallelism, rather than e.g. Tensor Parallelism?

neelnanda-io commented 1 year ago

Yeah, pipeline (i.e. each layer on a single GPU). It seems very important that an activation is on a single GPU, so hooks can kinda function as before, just ensuring they act on the correct device. Activations split across GPUs sound horrifying. I think this is worth taking a decent performance hit for.

(In reply to Ansh Radhakrishnan's comment of Mon, 19 Dec 2022, 4:42 pm.)

anshradh commented 1 year ago

Yeah, agreed: tensor parallelism seems pretty cursed.

I'd be interested in taking this on. My proposed implementation would likely involve the user passing the number of devices they want to load the model on, and then figuring out the distribution of layers from there (e.g. num_devices = 4 for gpt-2-small: embed and layers 0-2 on GPU 0, layers 3-5 on GPU 1, layers 6-8 on GPU 2, and layers 9-11 and unembed on GPU 3).
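A minimal sketch of that split, assuming layers are divided into contiguous chunks as evenly as possible, with any remainder going to the earliest devices (the proposal doesn't say how uneven splits would be handled). The function name `assign_layers` and the `"cuda:N"` device strings are illustrative, not TransformerLens API.

```python
from typing import Dict


def assign_layers(n_layers: int, num_devices: int) -> Dict[str, str]:
    """Map 'embed', 'blocks.<i>' and 'unembed' to device strings.

    Layers are split into contiguous chunks, as evenly as possible. If
    n_layers is not divisible by num_devices, the earlier devices each get
    one extra layer. The embed goes with the first chunk, the unembed with
    the last.
    """
    if not 1 <= num_devices <= n_layers:
        raise ValueError("need 1 <= num_devices <= n_layers")
    base, extra = divmod(n_layers, num_devices)
    placement: Dict[str, str] = {"embed": "cuda:0"}
    layer = 0
    for device in range(num_devices):
        # The first `extra` devices take one additional layer.
        count = base + (1 if device < extra else 0)
        for _ in range(count):
            placement[f"blocks.{layer}"] = f"cuda:{device}"
            layer += 1
    placement["unembed"] = f"cuda:{num_devices - 1}"
    return placement
```

`assign_layers(12, 4)` reproduces the gpt-2-small example (three blocks per device, embed on `cuda:0`, unembed on `cuda:3`), and `assign_layers(12, 2)` gives the 2-GPU MVP: blocks 0-5 with the embed on `cuda:0`, blocks 6-11 with the unembed on `cuda:1`. Keeping chunks contiguous means the residual stream crosses a device boundary only `num_devices - 1` times per forward pass.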

Seems like it would also be nice to support something like GPipe to speed up training (basically just split minibatches further into micro-batches and do gradient accumulation; should be simple enough).
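The gradient-accumulation half of that idea can be checked without any GPUs. A minimal sketch on a hypothetical toy scalar model y_hat = w * x with mean-squared-error loss: split a batch into micro-batches, compute each micro-batch's gradient, and accumulate. It deliberately leaves out what makes GPipe a *pipeline* (overlapping micro-batches across stages on different devices), and all function names are illustrative.

```python
from typing import List, Tuple

# A batch of (x, y) pairs for a toy scalar model y_hat = w * x.
Batch = List[Tuple[float, float]]


def grad_mse(w: float, batch: Batch) -> float:
    """Gradient of mean((w * x - y) ** 2) with respect to w."""
    return sum(2 * x * (w * x - y) for x, y in batch) / len(batch)


def split_microbatches(batch: Batch, n_micro: int) -> List[Batch]:
    """Split a batch into at most n_micro contiguous micro-batches (n_micro >= 1)."""
    size = -(-len(batch) // n_micro)  # ceiling division
    return [batch[i:i + size] for i in range(0, len(batch), size)]


def accumulated_grad(w: float, batch: Batch, n_micro: int) -> float:
    """Sum micro-batch gradients, each weighted by its share of the batch."""
    total = 0.0
    for micro in split_microbatches(batch, n_micro):
        # Weighting by len(micro) / len(batch) makes the sum equal the
        # full-batch mean gradient, even when micro-batches differ in size.
        total += grad_mse(w, micro) * len(micro) / len(batch)
    return total
```

For any number of micro-batches, `accumulated_grad(w, batch, n)` matches `grad_mse(w, batch)` up to floating-point rounding. That equivalence is what lets a pipeline trade one large batch for several smaller ones without changing the update.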

neelnanda-io commented 1 year ago

That API seems pretty reasonable to me! Supporting just a forward pass is probably fine; supporting efficient training is probably overkill until and unless someone wants to use this feature for something hardcore and large-scale. TransformerLens is not designed to train 20B models!

(In reply to Ansh Radhakrishnan's comment of Tue, 20 Dec 2022 at 18:12.)

jbloomAus commented 1 year ago

Tasks required to complete this issue:

We may add more if we decide that model parallelism requires other enhancements that aren't yet obvious.

jbloomAus commented 1 year ago

Solved now that #250 has resolved the last remaining issue.