microsoft / tutel

Tutel MoE: An Optimized Mixture-of-Experts Implementation

Enable running without bias and update ffn instantiation #216

Closed · vchiley closed this 10 months ago

vchiley commented 10 months ago

Clone of https://github.com/microsoft/tutel/pull/210 but this leaves _num_global_experts as a registered buffer.
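For context, a registered buffer is a non-trainable tensor tracked by the module, so it is saved in the state dict and moved with the module's device/dtype. A minimal sketch of what keeping `_num_global_experts` as a buffer looks like (the class name and constructor below are illustrative, not tutel's actual API):

```python
import torch

class MoELayer(torch.nn.Module):  # hypothetical, simplified stand-in for tutel's layer
    def __init__(self, num_global_experts: int):
        super().__init__()
        # Registering as a buffer (rather than a plain Python attribute) keeps the value
        # in state_dict() and moves it with .to(device), without making it trainable.
        self.register_buffer('_num_global_experts', torch.tensor(num_global_experts))
```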

This PR

Notes: All tensor naming is still the same. If bias=True, the generated state dict is identical to before. All old state dicts can still be loaded.

This only makes layer init more conventional (along with using a more conventional reset_parameters), e.g. like a torch.nn.Linear layer. (Standard PyTorch attaches parameters in __init__ instead of in an auxiliary self.update(ctx) fn; after the parameters are attached, standard PyTorch calls reset_parameters.)

I only moved parameter init; I didn't change how init happens. The seeding is unchanged, and the parameters are still initialized in an nn.Linear layer and copied over, as sketched below.
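A minimal sketch of the conventional pattern described above. The parameter names, shapes, and class here are assumptions for illustration, not tutel's exact code; the point is that parameters are attached in __init__ and then reset_parameters initializes them through a temporary nn.Linear and copies the values over:

```python
import torch

class ExpertFFN(torch.nn.Module):
    """Illustrative sketch of the conventional init pattern, not tutel's actual layer."""

    def __init__(self, num_local_experts: int, model_dim: int, hidden_size: int, bias: bool = True):
        super().__init__()
        # Attach parameters directly in __init__, as torch.nn.Linear does ...
        self.batched_fc1_w = torch.nn.Parameter(torch.empty(num_local_experts, hidden_size, model_dim))
        self.batched_fc2_w = torch.nn.Parameter(torch.empty(num_local_experts, model_dim, hidden_size))
        self.batched_fc1_bias = torch.nn.Parameter(torch.empty(num_local_experts, hidden_size)) if bias else None
        self.batched_fc2_bias = torch.nn.Parameter(torch.empty(num_local_experts, model_dim)) if bias else None
        # ... then fill in values via reset_parameters(), also like torch.nn.Linear.
        self.reset_parameters()

    def reset_parameters(self):
        # Initialize each expert through a temporary nn.Linear and copy the values over,
        # so the resulting values (and the RNG state consumed) match the old code path.
        for i in range(self.batched_fc1_w.shape[0]):
            fc1 = torch.nn.Linear(self.batched_fc1_w.shape[2], self.batched_fc1_w.shape[1],
                                  bias=self.batched_fc1_bias is not None)
            fc2 = torch.nn.Linear(self.batched_fc2_w.shape[2], self.batched_fc2_w.shape[1],
                                  bias=self.batched_fc2_bias is not None)
            with torch.no_grad():
                self.batched_fc1_w[i].copy_(fc1.weight)
                self.batched_fc2_w[i].copy_(fc2.weight)
                if self.batched_fc1_bias is not None:
                    self.batched_fc1_bias[i].copy_(fc1.bias)
                    self.batched_fc2_bias[i].copy_(fc2.bias)
```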

vchiley commented 10 months ago

Here @ghostplant claims:

The original expert parameter initialization seed was specially designed to ensure that initial parameter values are deterministic regardless of whether expert parameters are sharded or not. In other words, 2 GPUs training 2 experts in total should have exactly the same training loss as 4 GPUs training 2 experts (i.e. each GPU training half of an expert).

I don't see how the original init ever had this functionality.

Regardless, this change does not change how seeding happens. If the functionality previously existed, it still will.

The functionality @ghostplant describes is possible if every GPU uses the same seed: in reset_parameters we instantiate a tensor with ALL experts, then grab ONLY the subset required for the current GPU (see the sketch below).
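A hypothetical sketch of that approach, not tutel's code; the function name, arguments, and the nn.Linear-style uniform bound are assumptions used only to show the idea that identical seeding plus slicing makes the global initialization independent of how experts are sharded:

```python
import math
import torch

def init_local_expert_weights(num_global_experts: int, num_local_experts: int,
                              local_expert_index: int, model_dim: int,
                              hidden_size: int, seed: int = 1) -> torch.Tensor:
    # Hypothetical sketch: every rank seeds an identical generator, materializes
    # weights for ALL experts, then slices out only its local shard.
    gen = torch.Generator().manual_seed(seed)       # same seed on every GPU
    full = torch.empty(num_global_experts, hidden_size, model_dim)
    bound = 1.0 / math.sqrt(model_dim)              # nn.Linear-style uniform bound
    full.uniform_(-bound, bound, generator=gen)     # deterministic across ranks
    # Each GPU keeps only the experts it owns; the values it gets do not depend
    # on how many GPUs share the global expert set.
    start = local_expert_index * num_local_experts
    return full[start:start + num_local_experts].clone()
```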

ghostplant commented 10 months ago

Thanks! Please wait for checkpoint compatibility validation.

ghostplant commented 10 months ago

Please apply or accept this change: https://github.com/vchiley/tutel_msr/pull/2. Then we can merge it. Thanks!

ghostplant commented 10 months ago

Thanks for the contribution!