pytorch / pytorch

Tensors and Dynamic neural networks in Python with strong GPU acceleration
https://pytorch.org

Init Optim State in Checkpointing Does not Apply to Stateless Optimizers #133415

Closed mvpatel2000 closed 1 day ago

mvpatel2000 commented 4 weeks ago

🐛 Describe the bug

The new API for checkpointing, e.g. set_optimizer_state_dict, calls _init_optim_state.

https://github.com/pytorch/pytorch/blob/2a4304329be0ee592af0f1d8d0dd9428ed82a0c6/torch/distributed/checkpoint/state_dict.py#L585-L605

If the optimizer state is not already initialized, this function initializes it by faking all gradients as zero and taking a single optimizer step.
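Paraphrasing the linked source, the initialization logic looks roughly like the sketch below (simplified, not a verbatim copy); the important branch is the one taken when the optimizer state is empty but gradients are already populated.

import torch

# Simplified sketch of _init_optim_state, paraphrased from the linked code.
def _init_optim_state_sketch(optim: torch.optim.Optimizer) -> None:
    if optim.state:
        # Optimizer state already exists, nothing to do.
        return
    for param_group in optim.param_groups:
        for param in param_group["params"]:
            if param.grad is not None:
                # Empty state but existing gradients: this is the error path
                # that a stateless optimizer such as plain SGD always hits.
                raise RuntimeError("cannot fake a zero-grad step: gradients are not None")
            if param.requires_grad:
                param.grad = torch.zeros_like(param)
    # Take one step with all-zero gradients to materialize the state,
    # then drop the fake gradients again.
    optim.step(closure=None)
    optim.zero_grad(set_to_none=True)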

However, this causes problems for stateless optimizers such as SGD. For example, consider a simple snippet where I train a model, save a checkpoint, and then load it to resume, using the standard training loop structure:

for i, data in enumerate(train_loader, 0):
    # Get the inputs; data is a tuple of (inputs, labels)
    inputs, labels = data

    # Zero the parameter gradients
    optimizer.zero_grad()

    # Forward pass
    outputs = model(inputs)
    loss = loss_function(outputs, labels)

    # Backward pass and optimize
    loss.backward()
    optimizer.step()

    # Periodically checkpoint

# Load checkpoint   

# Resume training from ckpt
for i, data in enumerate(train_loader, 0):  
    # Get the inputs; data is a tuple of (inputs, labels)
    inputs, labels = data

    # Zero the parameter gradients
    optimizer.zero_grad()

    # Forward pass
    outputs = model(inputs)
    loss = loss_function(outputs, labels)

    # Backward pass and optimize
    loss.backward()
    optimizer.step()

    # Periodically checkpoint

this will fail because optimizer.zero_grad() is typically called at the top of each iteration, before the forward pass, so gradients are still populated when the checkpoint is loaded. Accordingly, the resume-from-checkpoint call fails for SGD: it calls _init_optim_state, which raises an error because the optimizer state is not set (SGD has no state) while the gradients are non-None in this standard training loop setup.

Instead, I believe the init call should check whether the optimizer is stateless.
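For illustration only, such a check could take the shape below (hypothetical; the name and placement are mine, not an actual patch): if no state was created even though real gradients exist, skip the fake step instead of raising.

import torch

def _init_optim_state_proposed(optim: torch.optim.Optimizer) -> None:
    # Hypothetical variant: treat "empty state but populated grads" as a
    # stateless optimizer (e.g. plain torch.optim.SGD) and do nothing.
    if optim.state:
        return  # state already initialized
    grads_exist = any(
        p.grad is not None
        for group in optim.param_groups
        for p in group["params"]
    )
    if grads_exist:
        return  # skip the fake zero-grad step instead of raising
    # ... otherwise fall through to the existing zero-grad step logic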

Versions

PyTorch 2.4

cc @XilunWu @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o

mvpatel2000 commented 4 weeks ago

@pytorchbot label "oncall: distributed"

mvpatel2000 commented 4 weeks ago

As an example of working around this (https://github.com/mosaicml/composer/pull/3549/files#diff-08f27b570e4d38016a29ab784d0a1abba2828c744bbdb2a1dc0aa5c421813f2bR2445-R2452), we now explicitly clear gradients after a call to fit in Composer, a library that builds a Trainer on top of PyTorch. This is not ideal and is only required because of the new checkpointing APIs.
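On the user side, the workaround amounts to roughly the following before the restore call (a sketch; model, optimizer, and optim_state_dict stand in for whatever the surrounding training code already has):

from torch.distributed.checkpoint.state_dict import set_optimizer_state_dict

# model, optimizer, and optim_state_dict are assumed to come from the
# existing training/checkpointing code.
# Drop stale gradients so _init_optim_state does not see empty state plus
# non-None grads and raise.
optimizer.zero_grad(set_to_none=True)
set_optimizer_state_dict(model, optimizer, optim_state_dict=optim_state_dict)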

mvpatel2000 commented 4 weeks ago

More generally, it seems checkpointing fails entirely with torch.optim.SGD, since checkpointing typically happens after a backward pass and before gradients are cleared, which triggers the error described above.
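A minimal single-process sketch of that failure mode (toy model; the point is only that a real step has run, plain SGD has created no state, and gradients have not been cleared yet):

import torch
from torch.distributed.checkpoint.state_dict import get_optimizer_state_dict

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# One ordinary training step; plain SGD keeps optimizer.state empty.
loss = model(torch.randn(8, 4)).sum()
loss.backward()
optimizer.step()

# Gradients are still populated here (zero_grad() only runs at the top of the
# next iteration), so _init_optim_state sees empty state plus non-None grads
# and raises on the affected versions.
osd = get_optimizer_state_dict(model, optimizer)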

weifengpy commented 3 weeks ago

triaged since @fegin self-assigned

mvpatel2000 commented 2 weeks ago

@fegin any chance we can include this in 2.4.1, since it is a regression?

mvpatel2000 commented 2 days ago

@fegin bumping -- any chance to address this, since it breaks checkpointing?