TorchJD / torchjd

Library for Jacobian descent with PyTorch. It enables optimization of neural networks with multiple losses (e.g. multi-task learning).
https://torchjd.org
MIT License

Use with AMP? #135

Closed: stevendbrown closed this issue 1 month ago

stevendbrown commented 1 month ago

I'd like to use UPGrad with torchjd.backward under torch.autocast('cuda'), but it's not clear to me how and when I should apply loss scaling in or around the aggregation. Everything works in FP32-only mode, except that I have had to cut my batch sizes in half. Have you thought about how torchjd.backward could be made compatible with AMP, or do you have any suggestions about where I should focus my efforts to sort it out?

ValerianRey commented 1 month ago

Disclaimer: I have no experience at all with AMP.

Looking at the documentation (https://pytorch.org/tutorials/recipes/recipes/amp_recipe.html), I'm assuming that you want to do something along the lines of this example:

```python
import torch

# ``net``, ``opt``, ``loss_fn``, ``data``, ``targets`` and ``device`` are defined earlier in the recipe.

# Constructs a ``scaler`` once, at the beginning of the convergence run, using default arguments.
# If your network fails to converge with default ``GradScaler`` arguments, please file an issue.
# The same ``GradScaler`` instance should be used for the entire convergence run.
# If you perform multiple convergence runs in the same script, each run should use
# a dedicated fresh ``GradScaler`` instance. ``GradScaler`` instances are lightweight.
scaler = torch.cuda.amp.GradScaler()

for epoch in range(0): # 0 epochs, this section is for illustration only
    for input, target in zip(data, targets):
        with torch.autocast(device_type=device, dtype=torch.float16):
            output = net(input)
            loss = loss_fn(output, target)

        # Scales loss. Calls ``backward()`` on scaled loss to create scaled gradients.
        scaler.scale(loss).backward()

        # ``scaler.step()`` first unscales the gradients of the optimizer's assigned parameters.
        # If these gradients do not contain ``inf``s or ``NaN``s, optimizer.step() is then called,
        # otherwise, optimizer.step() is skipped.
        scaler.step(opt)

        # Updates the scale for next iteration.
        scaler.update()

        opt.zero_grad() # set_to_none=True here can modestly improve performance
```
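To make the roles of scaler.step() and scaler.update() more concrete, here is a simplified pure-Python model of dynamic loss scaling. It is only a sketch of the idea, not GradScaler's actual implementation: the class ToyLossScaler and the sgd_update helper are hypothetical, and only the default values (init_scale=2**16, growth_factor=2.0, backoff_factor=0.5, growth_interval=2000) are taken from GradScaler's documented defaults.

```python
import math


class ToyLossScaler:
    """Hypothetical, simplified model of dynamic loss scaling (NOT GradScaler's source).

    Default values mirror GradScaler's documented defaults.
    """

    def __init__(self, init_scale=2.0**16, growth_factor=2.0,
                 backoff_factor=0.5, growth_interval=2000):
        self.scale = init_scale
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval
        self._good_steps = 0      # consecutive steps without inf/NaN
        self._found_inf = False   # result of the most recent step()

    def step(self, scaled_grads, apply_update):
        """Unscale the grads, then call apply_update only if they are all finite."""
        grads = [g / self.scale for g in scaled_grads]
        self._found_inf = not all(math.isfinite(g) for g in grads)
        if not self._found_inf:
            apply_update(grads)
        return not self._found_inf

    def update(self):
        """Back off after an overflow; grow after growth_interval clean steps."""
        if self._found_inf:
            self.scale *= self.backoff_factor
            self._good_steps = 0
        else:
            self._good_steps += 1
            if self._good_steps == self.growth_interval:
                self.scale *= self.growth_factor
                self._good_steps = 0


# Usage: one overflowing step, then two clean steps with a short growth interval.
params = [1.0]

def sgd_update(grads, lr=0.1):
    params[0] -= lr * grads[0]

scaler = ToyLossScaler(growth_interval=2)
skipped = not scaler.step([float("inf")], sgd_update)  # overflow: update is skipped
scaler.update()                                       # scale backs off from 2**16 to 2**15
for _ in range(2):
    scaler.step([0.5 * scaler.scale], sgd_update)     # the true gradient is 0.5
    scaler.update()                                   # after 2 clean steps the scale grows back
print(skipped, params[0], scaler.scale)
```

The key point for this issue is that step() only ever sees whatever ends up in the gradient fields, so it does not matter which function produced them, as long as they were computed from scaled losses.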

but with a call to torchjd.backward(...) replacing the call to scaler.scale(loss).backward(), and you're not sure exactly where to place the calls to the scaler. Am I correct?

As I understand it, scaler.scale(loss) simply multiplies the loss by a scale factor so that the corresponding gradients are not flushed to 0 by the limited precision of float16. In the case of Jacobian descent, I think you could try scaling all of your losses with a single scaler object.
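A quick way to see the underflow problem without a GPU: Python's struct module can round a float to IEEE 754 half precision (format "e"). In this sketch the gradient values are made up, and multiplying every value by one shared factor stands in for scaling all of the losses with a single scaler.

```python
import math
import struct


def to_fp16(x: float) -> float:
    """Round a Python float to IEEE 754 half precision and back."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


# Hypothetical gradient entries coming from two different losses.
true_grads = [1e-8, 3e-9]
scale = 2.0**16  # GradScaler's default initial scale

# Without scaling, both values are below the smallest float16 subnormal (~6e-8).
unscaled_fp16 = [to_fp16(g) for g in true_grads]

# With one shared scale: scale, store in float16, then divide the scale back out.
recovered = [to_fp16(g * scale) / scale for g in true_grads]

print(unscaled_fp16)  # [0.0, 0.0]
print(recovered)      # close to the true values
```

Because the same factor is applied to every loss, the relative sizes of the per-loss gradients are preserved, which is what an aggregator comparing those gradients needs.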

Then you should be able to call torchjd.backward(...) with the scaled losses as arguments. Lastly, since torchjd.backward also fills the .grad fields of your parameters, scaler.step(opt) should also work (it unscales those .grad fields before stepping), and the same goes for scaler.update().
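One condition is worth checking here: scaler.step(opt) divides the aggregated .grad by the scale s, so the result equals the update you would get without scaling only if the aggregator A satisfies A(sJ) = s·A(J) for s > 0. The pure-Python sketch below checks this for a plain mean of the Jacobian rows and for a hypothetical row-normalizing aggregator that violates it. Neither function is a torchjd aggregator, and the Jacobian J is made up; whether UPGrad satisfies this property exactly is worth confirming in the torchjd documentation.

```python
import math


def mean_aggregator(jacobian):
    """Average the rows (per-loss gradients); positively homogeneous."""
    n = len(jacobian)
    return [sum(col) / n for col in zip(*jacobian)]


def normalized_mean_aggregator(jacobian):
    """Hypothetical aggregator that rescales each row to unit norm first.

    It is scale-invariant, A(s * J) == A(J), so it is NOT positively homogeneous.
    """
    unit_rows = [[x / math.sqrt(sum(v * v for v in row)) for x in row] for row in jacobian]
    return mean_aggregator(unit_rows)


def step_direction_after_unscale(aggregator, jacobian, scale):
    """Aggregate the Jacobian of the scaled losses, then divide by the scale,
    mimicking the unscaling that scaler.step applies to .grad."""
    scaled_jacobian = [[scale * x for x in row] for row in jacobian]
    return [x / scale for x in aggregator(scaled_jacobian)]


J = [[1.0, 0.0], [0.5, 2.0]]  # hypothetical Jacobian: 2 losses, 2 parameters
s = 2.0**16

for agg in (mean_aggregator, normalized_mean_aggregator):
    expected = agg(J)
    got = step_direction_after_unscale(agg, J, s)
    print(agg.__name__, all(math.isclose(a, b) for a, b in zip(got, expected)))
```

With the mean aggregator the unscaled update matches exactly; with the scale-invariant one it ends up s times too small, so such an aggregator would silently shrink the learning rate under AMP.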

What I'm not sure about, though, is how torchjd would behave with mixed precision.

Could you provide me with a code example of what you want to do?

stevendbrown commented 1 month ago

I built a toy model with CIFAR-100, and running the backward pass with torchjd.backward, placed between scaling the losses with the GradScaler and stepping the optimizer through the GradScaler, worked fine there, so the issue must be in my own messier model definition. I'll reopen if I find whatever corner case is causing it. Thanks for making this code available! ✌️