silvandeleemput / memcnn

PyTorch Framework for Developing Memory Efficient Deep Invertible Networks
MIT License

Mixed precision and InvertibleModuleWrapper #57

Closed ibro45 closed 4 years ago

ibro45 commented 4 years ago

Description

When using the native PyTorch Automatic Mixed Precision (torch.cuda.amp), the InvertibleModuleWrapper does not work correctly:

What I Did

Minimal example:

import torch
from torch import nn
import torch.optim as optim
from torch.cuda.amp import autocast, GradScaler

import torchvision
from torchvision.models.resnet import resnet18, BasicBlock
import torchvision.transforms as transforms

import memcnn

# ------- Irrelevant -------------
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                          shuffle=True, num_workers=2)
# --------------------------------

USE_AMP = True                            # toggle torch.cuda.amp autocast + GradScaler
MANUALLY_CAST_INPUT_TO_INV_BLOCK = False  # workaround: cast inputs to float32 before the invertible block

class InvertibleBlock(nn.Module):
    def __init__(self, block, keep_input):
        """The input block should already be split across channels
        """
        super().__init__()

        invertible_module = memcnn.AdditiveCoupling(block)
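        # keep_input=False lets MemCNN discard the input activations after the
        # forward pass and reconstruct them via the inverse during backward.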
        self.invertible_block = memcnn.InvertibleModuleWrapper(fn=invertible_module, 
                                                               keep_input=keep_input, 
                                                               keep_input_inverse=keep_input)
    def forward(self, x, inverse=False):
        if MANUALLY_CAST_INPUT_TO_INV_BLOCK:
            x = x.float()
        if inverse:
            return self.invertible_block.inverse(x)
        else:
            return self.invertible_block(x)

model = resnet18(num_classes=10)

# Replace with invertible blocks
model.layer1 = nn.Sequential(
    InvertibleBlock(BasicBlock(32, 32), keep_input=False),
    InvertibleBlock(BasicBlock(32, 32), keep_input=False)
)

model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
scaler = GradScaler(enabled=USE_AMP)

for epoch in range(2):

    running_loss = 0.0
    for i, data in enumerate(trainloader):
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()

        # -------- AMP ----------
        with autocast(enabled=USE_AMP):
            outputs = model(inputs)
            loss = criterion(outputs, labels)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        # -----------------------

        running_loss += loss.item()
        if i % 100 == 99:
            print('[%d, %5d] loss: %.6f' % (epoch + 1, i + 1, running_loss / 100))
            running_loss = 0.0

print('Finished Training')
silvandeleemput commented 4 years ago

Good find! Thanks for the detailed example, I'll have a look at it on Monday. It appears to be a straightforward fix.

silvandeleemput commented 4 years ago

@ibro45 This issue is a lot trickier than I initially anticipated. The problem is very similar to using torch.cuda.amp with checkpointing; see https://github.com/pytorch/pytorch/issues/37730

At the moment, checkpointing and torch.cuda.amp don't work together nicely, and since memcnn.InvertibleModuleWrapper works in a similar way to checkpointing, it suffers from similar issues. Sadly, some of the solutions proposed in the above-mentioned ticket do not work for memcnn.InvertibleModuleWrapper, since it is even a bit more involved than checkpointing.
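For context, the kind of workaround discussed in that ticket for checkpoint-like autograd Functions is to capture the autocast state during the forward pass and restore it while recomputing activations in the backward pass. The sketch below only illustrates that general pattern; it is not MemCNN's InvertibleCheckpointFunction, and as noted above it does not transfer directly to it.

import torch

class RecomputeAwareFunction(torch.autograd.Function):
    """Toy checkpoint-style Function that recomputes its activations in backward."""

    @staticmethod
    def forward(ctx, run_function, x):
        ctx.run_function = run_function
        # Remember whether autocast was active so the recomputation can match it.
        ctx.had_autocast = torch.is_autocast_enabled()
        ctx.save_for_backward(x)
        with torch.no_grad():
            return run_function(x)

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        x = x.detach().requires_grad_(True)
        # Recompute under the same autocast state as the original forward pass,
        # so the recomputed graph produces tensors with matching dtypes.
        with torch.enable_grad(), torch.cuda.amp.autocast(enabled=ctx.had_autocast):
            out = ctx.run_function(x)
        torch.autograd.backward(out, grad_output)
        return None, x.grad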

In the end, I was able to get your example code working in two ways:

1) Enforcing float32: I added @custom_fwd(cast_inputs=torch.float32) and @custom_bwd before the forward and backward methods of the InvertibleCheckpointFunction, respectively. (I would recommend this for now; the decorator pattern is sketched below.)

2) Enforcing float16 model weights: calling model.half(), which converts all model weights to float16, and ditching the GradScaler (it doesn't like scaling float16 gradients and will complain). This saves even more memory, but I can't make any claims about the quality of the gradients or the results of the training process.
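To illustrate option 1, here is a minimal sketch of how torch.cuda.amp.custom_fwd and custom_bwd are applied to a torch.autograd.Function so its internals always run in float32, even under autocast. The toy matmul body is hypothetical and for illustration only; it is not MemCNN's actual InvertibleCheckpointFunction.

import torch
from torch.cuda.amp import custom_fwd, custom_bwd

class Float32MatMul(torch.autograd.Function):
    @staticmethod
    @custom_fwd(cast_inputs=torch.float32)  # incoming float16 CUDA tensors are cast to float32
    def forward(ctx, x, weight):
        ctx.save_for_backward(x, weight)
        return x @ weight

    @staticmethod
    @custom_bwd  # backward runs with the same autocast state as forward
    def backward(ctx, grad_output):
        x, weight = ctx.saved_tensors
        return grad_output @ weight.t(), x.t() @ grad_output

# Usage: even inside autocast, the custom Function computes in float32.
x = torch.randn(4, 8, device="cuda", requires_grad=True)
w = torch.randn(8, 2, device="cuda", requires_grad=True)
with torch.cuda.amp.autocast():
    y = Float32MatMul.apply(x, w)
print(y.dtype)  # torch.float32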

Neither solution is entirely satisfactory at the moment. Here is a link to the commit implementing option 1: https://github.com/silvandeleemput/memcnn/commit/881fd35df0d6bd6214810e3170107da884a98f49

Check test_amp.py for how I used this: https://github.com/silvandeleemput/memcnn/commit/881fd35df0d6bd6214810e3170107da884a98f49#diff-30ad4023a6eeedb230c16ee579c7a86b7c738fa5bd6ab34f939386d7969c2568

silvandeleemput commented 4 years ago

@ibro45 I have chosen to accept solution 1 as the current best "fix" for mixed-precision training. I am currently pushing a new release (1.5.0) of memcnn that includes the fix. As you recommended, I have also added a note to the InvertibleModuleWrapper documentation describing the limitation of mixed-precision training (float32 inputs). I hope this solves your problem.

ibro45 commented 4 years ago

Thanks a lot for the fix, and sorry for not getting back to you earlier; I got sidetracked.

Is the current solution not the best fix because it requires scaler.scale(loss).backward() inside of autocast, or is that not the only reason? Also, is it necessary that backward() runs under autocast and, if so, do you know why? I'm currently running backward() outside of autocast in my code and there are no errors, but is it safe to do that?

silvandeleemput commented 3 years ago

@ibro45

> Thanks a lot for the fix, and sorry for not getting back to you earlier; I got sidetracked.

No problem.

> Is the current solution not the best fix because it requires scaler.scale(loss).backward() inside of autocast, or is that not the only reason? Also, is it necessary that backward() runs under autocast and, if so, do you know why?
>
> I'm currently running backward() outside of autocast in my code and there are no errors, but is it safe to do that?

Yes, I think it should be.
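
For what it's worth, keeping backward() outside of autocast matches the usage pattern recommended in the PyTorch AMP documentation: autocast should wrap only the forward pass(es) and the loss computation, while the scaled backward pass and the optimizer step run outside of it. A self-contained toy sketch (using a plain nn.Linear rather than memcnn, and assuming a CUDA device is available):

import torch
from torch import nn
from torch.cuda.amp import autocast, GradScaler

model = nn.Linear(8, 2).cuda()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
scaler = GradScaler()

inputs = torch.randn(4, 8, device="cuda")
labels = torch.randint(0, 2, (4,), device="cuda")

# Only the forward pass and the loss computation run under autocast ...
with autocast():
    outputs = model(inputs)
    loss = criterion(outputs, labels)

# ... while the scaled backward pass and the optimizer step run outside of it.
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()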