huggingface / accelerate

🚀 A simple way to launch, train, and use PyTorch models on almost any device and distributed configuration, automatic mixed precision (including fp8), and easy-to-configure FSDP and DeepSpeed support
https://huggingface.co/docs/accelerate
Apache License 2.0

Why don't ML engineers use shampoo? 🧴 #3178

Open G-structure opened 1 month ago

G-structure commented 1 month ago

Hey,

I have been using Meta's implementation of Distributed Shampoo and am seeing ~20% faster convergence of transformer-based models compared to AdamW. Simo Ryu has done some nice investigations into the advantages of Shampoo.

I am looking to use Shampoo and SOAP as optimizers in accelerate, but their current implementations introduce some breaking changes.

Focusing on Shampoo for now:

Distributed Shampoo disables state_dict and load_state_dict in favor of custom distributed_state_dict and load_distributed_state_dict methods, both of which require the model's named_parameters() to be passed in as arguments. More info as to why here.

I have a hacky commit here to patch accelerate/optimizers. However, I am still forced to bypass accelerate.save() and call dist_checkpoint.save_state_dict() directly, since building the optimizer entry of the state_dict requires access to the model's named_parameters():

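```python
import torch.distributed.checkpoint as dist_checkpoint

# distributed_state_dict() needs the model's named_parameters(),
# so the standard optimizer.state_dict() path cannot be used here.
state_dict = {
    "model": model.state_dict(),
    "optim": optimizer.distributed_state_dict(key_to_param=model.named_parameters()),
}
dist_checkpoint.save_state_dict(
    state_dict=state_dict,
    storage_writer=dist_checkpoint.FileSystemWriter(CHECKPOINT_DIR),
)
```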

You can see this here in my e2-tts training code. I am able to save the model weights but am not yet able to load them again when using accelerate. This is where I am lost currently.
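For the loading half, a minimal sketch that mirrors the save snippet above: build a state_dict with the same keys, let dist_checkpoint.load_state_dict fill it in place, then hand each piece back to the model and optimizer. It assumes load_distributed_state_dict takes the loaded optimizer state plus the same key_to_param argument as distributed_state_dict. load_shampoo_checkpoint is a hypothetical helper, and the dcp parameter exists only so the checkpoint module can be stubbed out; like the save path, it bypasses accelerate's own checkpointing entirely.

```python
def load_shampoo_checkpoint(model, optimizer, checkpoint_dir, dcp=None):
    """Hypothetical helper: restore model and Shampoo optimizer state that was
    saved with dist_checkpoint.save_state_dict as in the snippet above."""
    if dcp is None:
        import torch.distributed.checkpoint as dcp

    # dist_checkpoint loads in place, so first build a state_dict with the
    # same keys (and tensor shapes) as the one that was saved.
    state_dict = {
        "model": model.state_dict(),
        "optim": optimizer.distributed_state_dict(key_to_param=model.named_parameters()),
    }
    dcp.load_state_dict(
        state_dict=state_dict,
        storage_reader=dcp.FileSystemReader(checkpoint_dir),
    )
    # Push the loaded values back into the live model and optimizer.
    # named_parameters() is a one-shot generator, so request a fresh one.
    model.load_state_dict(state_dict["model"])
    optimizer.load_distributed_state_dict(
        state_dict["optim"], key_to_param=model.named_parameters()
    )
```

Because both calls pass model.named_parameters(), the model has to be wrapped the same way at save and load time; DDP, for instance, prefixes every parameter name with "module.".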

Also, since I don't have access to the named_parameters until accelerate.prepare_model() is called, the Shampoo optimizer needs to be defined in the model definition, which makes it awkward to switch between optimizers (see here).

Ideally, I'd be able to do something like this, where I pass in the optimizer as I can with AdamW:

```python
from distributed_shampoo.distributed_shampoo import DistributedShampoo
from distributed_shampoo.shampoo_types import AdamGraftingConfig
from e2_tts_pytorch import E2TTS
from e2_tts_pytorch.trainer import E2Trainer

e2tts = E2TTS(
    cond_drop_prob=0.0,
    transformer = dict(
        dim = 512,
        depth = 2,
        heads = 6,
        skip_connect_type = 'concat'
    ),
    mel_spec_kwargs = dict(
        filter_length = 1024,
        hop_length = 256,
        win_length = 1024,
        n_mel_channels = 100,
        sampling_rate = 24000,
    ),
    frac_lengths_mask = (0.7, 0.9)
)

optimizer = DistributedShampoo(
    e2tts.parameters(),
    lr=7.5e-5,
    betas=(0.9, 0.999),
    epsilon=1e-12,
    weight_decay=1e-05,
    max_preconditioner_dim=8192,
    precondition_frequency=100,
    use_decoupled_weight_decay=False,
    grafting_config=AdamGraftingConfig(
        beta2=0.999,
        epsilon=1e-08,
    ),
)

trainer = E2Trainer(
    e2tts,
    optimizer
)

# train_dataset, epochs and batch_size are defined elsewhere in the script
trainer.train(train_dataset, epochs, batch_size, save_step=50)
```
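One way to keep the optimizer swappable while still deferring its construction until after the model is prepared is to pass a factory instead of an instance. A minimal sketch, assuming a trainer that calls the factory itself; MiniTrainer, optimizer_fn and make_optimizer are hypothetical names, and the prepare step is a plain callable (accelerator.prepare_model in a real script) so the pattern runs without accelerate.

```python
class MiniTrainer:
    """Hypothetical trainer that takes an optimizer factory and calls it only
    after the model has been prepared, so whichever optimizer is chosen
    (AdamW, DistributedShampoo, ...) is built against the prepared model."""

    def __init__(self, model, optimizer_fn, prepare_model=lambda m: m):
        # prepare_model stands in for e.g. accelerator.prepare_model; the
        # identity default keeps the sketch runnable without accelerate.
        self.model = prepare_model(model)
        # The optimizer is created last, from the prepared model.
        self.optimizer = optimizer_fn(self.model)


def make_optimizer(model, optimizer_cls, **kwargs):
    # Works for any optimizer class that follows the torch.optim convention
    # of taking an iterable of parameters as its first argument.
    return optimizer_cls(model.parameters(), **kwargs)
```

In a real script the factory might be functools.partial(make_optimizer, optimizer_cls=DistributedShampoo, lr=7.5e-5) with prepare_model=accelerator.prepare_model; switching to torch.optim.AdamW then only changes the factory, not the model definition.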

Of course, when I set everything up with torch DDP instead of accelerate, everything works as intended :/

What would be the best approach for accelerate to support these custom optimizers (ones not part of torch)? My plan currently is to write a ShampooPlugin along the lines of the DeepSpeedPlugin, but it would be nice if the shampoo optimizer could be detected automatically without having to change the accelerate config. I am willing to put in the work to solve this so more projects can benefit from using these new optimizers with accelerate.
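On detecting Shampoo automatically: rather than a new config-file plugin, one option is duck typing at save/load time, checking whether the optimizer exposes the named-parameter methods. A minimal sketch; optimizer_state_dict and load_optimizer_state_dict are hypothetical helpers, not existing accelerate functions, and the method names assume Meta's DistributedShampoo API as described above.

```python
def optimizer_state_dict(optimizer, model):
    """Hypothetical helper: use the named-parameter API when the optimizer
    provides one (as DistributedShampoo does), else the standard one."""
    # Check for the named-parameter variant first: a torch.optim.Optimizer
    # subclass always has a state_dict attribute, even when it is disabled.
    if hasattr(optimizer, "distributed_state_dict"):
        return optimizer.distributed_state_dict(key_to_param=model.named_parameters())
    return optimizer.state_dict()


def load_optimizer_state_dict(optimizer, model, state_dict):
    """Counterpart of optimizer_state_dict for loading."""
    if hasattr(optimizer, "load_distributed_state_dict"):
        optimizer.load_distributed_state_dict(
            state_dict, key_to_param=model.named_parameters()
        )
    else:
        optimizer.load_state_dict(state_dict)
```

The order of the checks matters: DistributedShampoo still inherits a state_dict method from torch.optim.Optimizer (it is disabled rather than absent), so testing for state_dict first would pick the wrong path.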

Any guidance would be much appreciated. :)

bghira commented 1 month ago

That optim lacks other torch-specific expectations, and, as with learning rate schedulers, some off-the-wall optimisers just don't work without modification. That distributed ZeRO Shampoo optim is a WIP technical prototype and not meant to be used in production, for example.

The SOAP one worked as expected for me. It just needs a closure passed to its step(). See here: https://github.com/bghira/SimpleTuner/blob/main/helpers/training/optimizers/soap/__init__.py

bghira commented 1 month ago

Other problems with the original optim implementations linked are that they don't work with torch.compile, and they have very slow performance (more pronounced in SOAP) and high memory overhead (also more pronounced in SOAP), even with ZeRO offload.

github-actions[bot] commented 5 days ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.