microsoft / mup

maximal update parametrization (µP)
https://arxiv.org/abs/2203.03466
MIT License

Is this compatible with DeepSpeed / ZeRO? #6

Closed · StellaAthena closed this issue 1 year ago

StellaAthena commented 2 years ago

I train large language models using DeepSpeed's ZeRO optimizer. Does this library support ZeRO?

edwardjhu commented 2 years ago

Hi Stella,

Thanks for bringing the question of DeepSpeed compatibility to our attention!

In principle, yes. Our MuAdam is just a wrapper that scales the learning rate differently for each parameter group before passing the groups on to Adam.

See https://github.com/microsoft/mup/blob/c9d67001c47ae254ea4b7e26146ffd059520b6ba/mup/optim.py#L38

If I understand correctly, ZeRO is a technique that makes the underlying optimizer more efficient without changing the math. If you can pass the parameter groups created by MuAdam to ZeRO's Adam, it should work as expected.
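For example, a rough sketch of what we have in mind (not something we've tested with DeepSpeed ourselves; build_model and ds_config are placeholders):

    import deepspeed
    from mup import MuAdam

    # MuAdam builds parameter groups whose learning rates are already
    # scaled per muP; set_base_shapes must be called on the model first
    model = build_model()                        # placeholder for your model
    optimizer = MuAdam(model.parameters(), lr=1e-3)

    # DeepSpeed can wrap an existing torch optimizer, so ZeRO shards the
    # optimizer states of the muP-scaled groups without changing the math
    model_engine, optimizer, _, _ = deepspeed.initialize(
        model=model,
        optimizer=optimizer,
        config=ds_config,                        # placeholder DeepSpeed config
    )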

Please let us know if our understanding of ZeRO is correct. We'll also look into the best way to integrate with DeepSpeed, given its popularity among folks who train large models.

zhuzilin commented 2 years ago

@StellaAthena @edwardjhu Our team is also trying to migrate muP to gpt-neox. It seems to me that MuAdam works fine with the ZeRO optimizer, as all it does is adjust the learning rates. The figure below shows the "coord check" test for a 12-layer model (basically the small.yml config in gpt-neox), with widths from 12 to 768:

[coord check figure: 12-layer model, widths 12 to 768]

As you can see, mup works well with gpt-neox!

I think the major barrier to using mup is that we cannot use the MuReadout in mup as the output layer; we need a model-parallel version of it. Applying the coord check is also hard on large models, because DeepSpeed somehow keeps GPU memory occupied even after a model is garbage collected, and the current coord check implementation needs to initialize multiple models sequentially, which leads to OOM for large models.
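For reference, here is roughly how the coord check can be driven with lazy constructors so that only one model is alive at a time (a sketch; build_gpt, base.bsh, and dataloader are placeholders, not actual gpt-neox code):

    import gc
    import torch
    from mup import set_base_shapes
    from mup.coord_check import get_coord_data, plot_coord_data

    def lazy_model(width):
        # lazy constructor: the model is only built when get_coord_data
        # gets to this width, so the widths don't coexist in GPU memory
        return lambda: set_base_shapes(build_gpt(width), 'base.bsh')

    models = {w: lazy_model(w) for w in (12, 48, 192, 768)}
    df = get_coord_data(models, dataloader)
    plot_coord_data(df, save_to='coord_check.png')

    # release cached blocks between large runs; this does not recover the
    # memory DeepSpeed itself keeps holding, which is the OOM problem above
    gc.collect()
    torch.cuda.empty_cache()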

If you're interested in adapting mup to gpt-neox, maybe we can work together :)

StellaAthena commented 2 years ago

@zhuzilin Yeah, DeepSpeed's garbage collection is, well, garbage. We actually have an auxiliary tool, tools/kill_all.sh, that kills all DeepSpeed processes across all connected machines, because it's so common for DS to leave GPU memory allocated when it crashes.

What GPUs are you running tests on? If you're having trouble getting access to enough GPUs, I may be able to help. DM me on Discord to work out the details :)

thegregyang commented 2 years ago

@zhuzilin So MuReadout essentially just does the following:

    def forward(self, x):
        # scale the input by output_mult, then divide by the width multiplier
        return super().forward(
            self.output_mult * x / self.weight.infshape.width_mult())

If for some reason you can't use MuReadout as-is, you can manually write out this forward, assuming you have already called set_base_shapes. For example, in your model's forward function:

    output = (x @ self.output_weights.T * self.output_mult
              / self.output_weights.infshape.width_mult()
              + self.output_bias)

I'm not sure exactly what you need for model parallelism, but hopefully this helps. If you do figure it out, we would like to hear how to incorporate your insight into mup so it works right out of the box.
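In case it helps, here is one hypothetical way to adapt this (the names and the parallel layer are assumptions, not part of mup): since the readout scaling just multiplies the input by a scalar, it commutes with splitting the output dimension across model-parallel ranks.

    import torch.nn as nn

    class ParallelMuReadout(nn.Module):
        # hypothetical wrapper: applies the MuReadout input scaling around
        # any model-parallel linear, e.g. a Megatron-style column-parallel one
        def __init__(self, parallel_linear, output_mult=1.0):
            super().__init__()
            self.linear = parallel_linear
            self.output_mult = output_mult

        def forward(self, x):
            # weight.infshape is attached by mup's set_base_shapes; the
            # readout's width_mult depends on fan_in, which is not sharded
            # under column parallelism, so every rank computes the same value
            width_mult = self.linear.weight.infshape.width_mult()
            return self.linear(x * self.output_mult / width_mult)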

StellaAthena commented 2 years ago

@zhuzilin I've set aside time to add support for muP to GPT-NeoX this week and would love to check out your code. Where can I find it? Perhaps you can open a PR detailing what has been done and what still needs to be done?

thegregyang commented 2 years ago

@zhuzilin @StellaAthena How is the DeepSpeed integration going? We can connect you with members of the DeepSpeed team if necessary.