mosaicml / composer

Supercharge Your Model Training
http://docs.mosaicml.com
Apache License 2.0

FSDP Wrapping Alters Optimizer's Parameter Tracking Behavior #3493

Closed: DavidBert closed this issue 3 months ago

DavidBert commented 3 months ago

Hi! There appears to be an inconsistency in the behavior of the optimizer before and after wrapping with Fully Sharded Data Parallel (FSDP).

When the model is wrapped with FSDP, the set of parameters tracked by the optimizer changes: after wrapping, the optimizer tracks all model parameters.

The optimizer should maintain its original parameter tracking behavior after FSDP wrapping. For example, if the optimizer was initially tracking only parameters from specific submodules, it should continue to do so after FSDP wrapping.

This behavior change could lead to unexpected results, especially in cases where users have intentionally set up their optimizers to track only specific parts of the model.

Relevant code: https://github.com/mosaicml/composer/blob/55f0b7d1880caaf218a56212e86f174bbc463d12/composer/distributed/dist_strategy.py#L719C1-L725C51
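
To see which parameters an optimizer is actually tracking (for example, before and after handing it to the Trainer), a small helper can map each tensor in optimizer.param_groups back to its name in the model. This is a minimal plain-PyTorch sketch, not part of Composer; tracked_param_names and TinyModel are hypothetical names. FSDP may flatten or rename parameters, so names looked up on a wrapped model can differ from the originals.

```python
import torch.nn as nn
from torch.optim import Adam, Optimizer


def tracked_param_names(model: nn.Module, optimizer: Optimizer) -> list[str]:
    """Return the model-level names of the parameters the optimizer tracks."""
    # Map each parameter tensor's identity to its qualified name in the model.
    id_to_name = {id(p): name for name, p in model.named_parameters()}
    names = []
    for group in optimizer.param_groups:
        for p in group["params"]:
            # Tensors that do not belong to `model` are reported as "<unknown>".
            names.append(id_to_name.get(id(p), "<unknown>"))
    return names


class TinyModel(nn.Module):
    """Hypothetical two-layer model used only for illustration."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 4, 3)
        self.fc = nn.Linear(4, 2)


model = TinyModel()
optimizer = Adam(model.conv1.parameters(), lr=1e-3)
print(tracked_param_names(model, optimizer))  # ['conv1.weight', 'conv1.bias']
```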

mvpatel2000 commented 3 months ago

Hm... yes, I think it might not be properly saving the subset of parameters when the param group doesn't contain all params.

Code to save: https://github.com/mosaicml/composer/blob/55f0b7d1880caaf218a56212e86f174bbc463d12/composer/distributed/dist_strategy.py#L265-L295
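
One way to sketch "save the subset, then restore it" in plain PyTorch is to record each param group's hyperparameters and parameter names before wrapping, then look those names up again in the wrapped model and rebuild the optimizer. This is not Composer's code: record_param_groups, rebuild_optimizer, clean_name, and Wrapper are illustrative, and it assumes a wrapper only adds an _fsdp_wrapped_module. prefix to parameter names. Real FSDP can also flatten or shard parameters, which this sketch does not handle.

```python
import torch.nn as nn
from torch.optim import Adam, Optimizer

# Assumed prefix(es) a wrapper inserts into parameter names.
WRAPPER_PREFIXES = ("_fsdp_wrapped_module.",)


def clean_name(name: str) -> str:
    """Strip wrapper prefixes from a qualified parameter name."""
    for prefix in WRAPPER_PREFIXES:
        name = name.replace(prefix, "")
    return name


def record_param_groups(model: nn.Module, optimizer: Optimizer) -> list[dict]:
    """Snapshot each param group as its hyperparameters plus its parameter names."""
    id_to_name = {id(p): clean_name(n) for n, p in model.named_parameters()}
    snapshot = []
    for group in optimizer.param_groups:
        hparams = {k: v for k, v in group.items() if k != "params"}
        # Raises KeyError if the optimizer tracks a tensor that is not in `model`.
        names = [id_to_name[id(p)] for p in group["params"]]
        snapshot.append({"param_names": names, **hparams})
    return snapshot


def rebuild_optimizer(model: nn.Module, optimizer_cls: type, snapshot: list[dict]) -> Optimizer:
    """Recreate the optimizer over the wrapped model, keeping only the recorded params."""
    name_to_param = {clean_name(n): p for n, p in model.named_parameters()}
    groups = []
    for entry in snapshot:
        group = {k: v for k, v in entry.items() if k != "param_names"}
        group["params"] = [name_to_param[n] for n in entry["param_names"]]
        groups.append(group)
    return optimizer_cls(groups)


class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 4, 3)
        self.fc = nn.Linear(4, 2)


class Wrapper(nn.Module):
    """Stand-in for a wrapper that prefixes parameter names (not real FSDP)."""

    def __init__(self, inner: nn.Module):
        super().__init__()
        self._fsdp_wrapped_module = inner


original = TinyModel()
opt = Adam(original.conv1.parameters(), lr=1e-2)
snapshot = record_param_groups(original, opt)
new_opt = rebuild_optimizer(Wrapper(original), Adam, snapshot)
print(len(new_opt.param_groups[0]["params"]))  # 2
```

Names are used instead of tensor identities because wrapping can replace the parameter objects, while the (cleaned) names stay stable.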

@sashaDoubov what do you think? iirc you added this bit

@DavidBert do you have a repro we can use here?

DavidBert commented 3 months ago

Thanks for the quick answer! This code should demonstrate the undesired behavior:

```python
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

from composer import Trainer
from composer.models import ComposerClassifier
from composer.utils import dist


class Model(nn.Module):
    """Toy convolutional neural network architecture in PyTorch for MNIST."""

    def __init__(self, num_classes: int = 10):
        super().__init__()

        self.num_classes = num_classes

        self.conv1 = nn.Conv2d(1, 16, (3, 3), padding=0)
        self.conv2 = nn.Conv2d(16, 32, (3, 3), padding=0)
        self.bn = nn.BatchNorm2d(32)
        self.fc1 = nn.Linear(32 * 16, 32)
        self.fc2 = nn.Linear(32, num_classes)

    def forward(self, x):
        out = self.conv1(x)
        out = F.relu(out)
        out = self.conv2(out)
        out = self.bn(out)
        out = F.relu(out)
        out = F.adaptive_avg_pool2d(out, (4, 4))
        out = torch.flatten(out, 1, -1)
        out = self.fc1(out)
        out = F.relu(out)
        return self.fc2(out)


transform = transforms.Compose([transforms.ToTensor()])
dataset = datasets.MNIST("data", train=True, download=True, transform=transform)
sampler = dist.get_sampler(dataset, shuffle=True)

dataloader = DataLoader(dataset, batch_size=32, sampler=sampler)

# Create the module once so the optimizer references the parameters of the model
# that is actually trained (a second Model() would have unrelated parameters).
module = Model()

# Intentionally track only conv1's parameters (its weight and bias).
optimizer = Adam(module.conv1.parameters(), lr=1e-3)
unwrapped_optimizer = copy.deepcopy(optimizer)

trainer = Trainer(
    model=ComposerClassifier(module=module, num_classes=10),
    train_dataloader=dataloader,
    max_duration="2ep",
    optimizers=optimizer,
    fsdp_config={
        "sharding_strategy": "_HYBRID_SHARD_ZERO2",
        "mixed_precision": "PURE",
        "backward_prefetch": "BACKWARD_PRE",
    },
)

# Compare how many parameters the optimizer tracks before and after FSDP wrapping.
with trainer.state.model.module.summon_full_params(trainer.state.model.module):
    nb_parameters_before_fsdp = len(unwrapped_optimizer.param_groups[0]["params"])
    nb_parameters_after_fsdp = len(trainer.state.optimizers[0].param_groups[0]["params"])
    assert nb_parameters_before_fsdp == nb_parameters_after_fsdp, (
        f"expected {nb_parameters_before_fsdp} but got {nb_parameters_after_fsdp}"
    )
```

mvpatel2000 commented 3 months ago

CC: @eracah can you take a look?

eracah commented 3 months ago

Good find, @DavidBert! It looks like here we re-init the optimizer with all parameters, even if you created it with only a subset of params (this happens for optimizers with a single param_group). We'll file a bug to recreate it using only the parameters from the original optimizer.

For now, to unblock yourself, you could make one param_group with the Model().conv1.parameters() and another param_group with the rest of the params, using the Optimizer.add_param_group function. This creates more than one param group, which guarantees the optimizer is recreated with the correct setup.

If all you want to do is freeze the other parameters, you can use the optimizer as normal but set requires_grad=False for all the other parameters. That might be simpler than creating an optimizer with a subset of the parameters.
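
The two-param-group workaround could look like this toy sketch in plain PyTorch. TinyModel is a hypothetical stand-in for the MNIST Model, and giving the second group lr=0.0 is an assumption made here so the remaining parameters stay fixed; the comment above only says that having more than one param group leads to the optimizer being recreated with the original setup.

```python
import torch
import torch.nn as nn
from torch.optim import Adam


class TinyModel(nn.Module):
    """Hypothetical stand-in for the MNIST model."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 4, 3)
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        out = torch.relu(self.conv1(x))  # (N, 4, H-2, W-2)
        out = out.mean(dim=(2, 3))       # global average pool -> (N, 4)
        return self.fc(out)


model = TinyModel()

# Group 0: the parameters we actually want to train.
optimizer = Adam(model.conv1.parameters(), lr=1e-3)

# Group 1: every remaining parameter. lr=0.0 is an assumption so that these
# parameters are not updated while the optimizer still has more than one group.
conv1_ids = {id(p) for p in model.conv1.parameters()}
rest = [p for p in model.parameters() if id(p) not in conv1_ids]
optimizer.add_param_group({"params": rest, "lr": 0.0})

print(len(optimizer.param_groups))  # 2
```

Unlike freezing with requires_grad=False, the lr=0.0 group still has gradients computed and Adam moment buffers allocated, so it uses more memory.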

Lmk if that helps unblock you! cc: @sashaDoubov @mvpatel2000

DavidBert commented 3 months ago

Thanks @eracah! Setting requires_grad=False is indeed the easiest fix so far for my specific problem. I tested it and it works as expected. Thanks for the help and for all your work; Composer is a very cool library!
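
The requires_grad=False approach can be sketched in plain PyTorch as follows. TinyModel is again a hypothetical stand-in, and only conv1 is left trainable. PyTorch's Adam skips any parameter whose .grad is None, which is the case for frozen parameters after backward.

```python
import torch
import torch.nn as nn
from torch.optim import Adam


class TinyModel(nn.Module):
    """Hypothetical stand-in for the MNIST model."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 4, 3)
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        out = torch.relu(self.conv1(x))
        return self.fc(out.mean(dim=(2, 3)))


model = TinyModel()

# Freeze everything except conv1: frozen parameters never receive a .grad.
for name, param in model.named_parameters():
    param.requires_grad = name.startswith("conv1.")

# The optimizer is built over all parameters "as normal".
optimizer = Adam(model.parameters(), lr=1e-3)
```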

eracah commented 3 months ago

No problem! Glad we could unblock you, @DavidBert!

mvpatel2000 commented 3 months ago

This should now be fixed in general.

DavidBert commented 3 months ago

Thank you guys!