pytorch / captum

Model interpretability and understanding for PyTorch
https://captum.ai
BSD 3-Clause "New" or "Revised" License

Massive VRAM Usage for Feature Ablation and Shapley Value Sampling #646

Open ndalton12 opened 3 years ago

ndalton12 commented 3 years ago

❓ Questions and Help

Hello, I have just set up Captum to analyze a ResNet50 model (taken directly from torchvision). Captum works as expected for most of the attribution methods. However, when I try to use feature ablation and Shapley value sampling, I get out-of-memory errors despite using perturbations_per_eval=1 (the minimum value). I also tried running the same code with data parallel across 3 GPUs, but the program still OOMs.

For reference, I am using RTX 8000 GPUs with 48 GB of VRAM each. Using feature ablation, for example, **the program almost swallowed the entire (roughly, minus a few other programs in the background) 48 x 3 = 144 GB of VRAM!** Is this expected? The data parallel load balancing was not great, however; one GPU had only about 30 GB of its 48 GB used before another of them OOM'd. I have a total of 4x RTX 8000 to try, but one is currently in use. Also for reference, the ResNet50 model should take up at most maybe 5 GB of VRAM during training. The tensor size is standard ImageNet (Nx3x224x224).

Here is a code snippet:

    print("Getting feature ablation...")
    fig = captum_viz(prep_img, model, None, default_cmap, FeatureAblation, target=1,
                     perturbations_per_eval=1, show_progress=True)
    fig.savefig(file_name_to_export + '_feature_ablation.png')
def captum_viz(prep_img, model, background, default_cmap, method, target=1, **kwargs):
    single_img = prep_img.clone().cuda()
    class_instance = method(model)
    # error is on the instance.attribute line, errors with or without background
    if background is not None:
        background = background.clone()
        attributions_ig = class_instance.attribute(single_img, baselines=background, target=target, **kwargs)
    else:
        attributions_ig = class_instance.attribute(single_img, target=target, **kwargs)

...
vivekmig commented 3 years ago

Hi @ndalton12, this is definitely not expected; it seems like there is a memory leak causing this issue.

I attempted to reproduce this with the latest versions of Captum / PyTorch by loading a pretrained ResNet50 from torchvision and using a 3x224x224 random tensor in a Colab notebook with a GPU (16 GB memory), but this seems to work fine. I used perturbations_per_eval = 500, since the total of 3x224x224 = 150528 forward passes would take substantially longer with perturbations_per_eval = 1.
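
Roughly, the setup I tried looks like this (a sketch rather than the exact notebook; the target index is an arbitrary placeholder):

    import torch
    import torchvision
    from captum.attr import FeatureAblation

    model = torchvision.models.resnet50(pretrained=True).cuda().eval()
    inp = torch.rand(1, 3, 224, 224).cuda()  # random ImageNet-sized input

    ablator = FeatureAblation(model)
    # 3 * 224 * 224 = 150528 single-feature ablations, evaluated 500 at a time
    attributions = ablator.attribute(inp, target=0, perturbations_per_eval=500)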

What versions of Captum / PyTorch are you using? Also, if you are able to reproduce the issue with a Colab notebook GPU and can share the notebook, that would be very helpful for debugging further.

Alternatively, if you are able to install Captum from source, you can try adding the following function to print some memory debugging info and calling it from here in the inner loop of Feature Ablation (or after every k iterations).

    def debug_memory():
        import collections, gc, torch
        print(torch.cuda.memory_summary())  # print for multiple devices if necessary
        # count live tensors found by the garbage collector, grouped by (device, dtype, shape)
        tensors = collections.Counter((str(o.device), o.dtype, tuple(o.shape))
                                      for o in gc.get_objects()
                                      if torch.is_tensor(o))
        for line in tensors.items():
            print('{}\t{}'.format(*line))
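
For example, the call site could look roughly like this (a sketch, not the actual Captum source; `i` stands for whatever loop counter is in scope and `k` for the reporting interval):

    # inside the feature ablation loop, report memory stats every k perturbations
    k = 1000
    if i % k == 0:
        debug_memory()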

These results would be helpful to debug further.

ndalton12 commented 3 years ago

Hello, yes it certainly looks like a memory leak.

I was not able to reproduce the error in Colab originally, but I have now narrowed down the problem. For whatever reason, using the model (specifically the nn.Module) produced by wrapping it in a PyTorch Lightning module leads to this leak. The fix was to create a fresh ResNet50 from torchvision and copy in the state dict from the PyTorch-Lightning-wrapped version:

    import torch
    import torchvision

    model2 = torchvision.models.resnet50()
    num_ftrs = model2.fc.in_features
    model2.fc = torch.nn.Linear(num_ftrs, 2)  # replace the head for my 2-class task
    model2.classifier = model2.fc
    model2.load_state_dict(model.model.state_dict())  # copy weights from the Lightning-wrapped model
    model2.cuda()

Now, when I use model2 instead of model (the PyTorch Lightning module, which subclasses nn.Module among other things), there is no memory leak. Weirdly enough, even when I use model.model (the nn.Module held inside the PyTorch Lightning module), the memory leak is still present. Parameter freezing (setting param.requires_grad = False) doesn't seem to be the difference maker, even though it's the only difference between model2 and model.model that I can think of. The leak seems to be related to batch norm, since that's what shows up in the stack trace. I am not sure whether this is a PyTorch Lightning issue or something else.
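
For reference, the attribution call with the copied model is then roughly (a sketch; prep_img is the same input tensor as in my earlier snippet):

    from captum.attr import FeatureAblation

    # attribute against the freshly constructed torchvision model instead of the
    # Lightning-wrapped one; with model2 there is no leak
    ablator = FeatureAblation(model2)
    attributions = ablator.attribute(prep_img.clone().cuda(), target=1,
                                     perturbations_per_eval=1)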

Also, on a side note, the show progress feature does not seem to work. This fix does let me run the code without error, though, so feel free to close the issue if you see fit; but it is more of a workaround than a fundamental fix.

ndalton12 commented 3 years ago

Is it also expected that memory usage increases non-linearly with the background (baseline) image batch size? I can now run with a background batch size of 1, but with a batch size of 64 the program starts asking for 90+ GB of VRAM, even with the fix above. Normally a batch size of 64 should fit with no problem, nowhere near that amount of memory.

vivekmig commented 3 years ago

Interesting, thanks for the detailed information! It does sound like it could be related to PyTorch Lightning, since the unwrapped torchvision module doesn't show the issue, but we would need a full example notebook to reproduce and investigate further.

The show progress feature was recently added in #630 (and our website API docs update immediately based on master), so it's not in the release (pip / conda) builds yet. It will be available in the next release, or alternatively you can install Captum from source to try it out now.

The memory usage increasing non-linearly is also not expected. Are you able to reproduce this in Colab with just the original torchvision model? Also, just to confirm, for this, are you using both an input and baseline batch size of 64 (with perturbations per eval = 1) or is the input batch size still 1? For feature ablation, we only support baselines corresponding to each input example (or a single baseline example for all inputs) and not a larger number of baseline examples than input examples.
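
For example, roughly (a sketch with a hypothetical batch of 64 inputs, just to illustrate the supported baseline shapes):

    import torch
    import torchvision
    from captum.attr import FeatureAblation

    model = torchvision.models.resnet50(pretrained=True).cuda().eval()
    fa = FeatureAblation(model)
    inputs = torch.rand(64, 3, 224, 224).cuda()

    # supported: a single baseline example broadcast across all 64 inputs
    fa.attribute(inputs, baselines=torch.zeros(1, 3, 224, 224).cuda(),
                 target=1, perturbations_per_eval=1)

    # supported: one baseline per input example (baseline batch == input batch)
    fa.attribute(inputs, baselines=torch.zeros(64, 3, 224, 224).cuda(),
                 target=1, perturbations_per_eval=1)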

ndalton12 commented 3 years ago

Hi @vivekmig, I have managed to create a somewhat reproducible example: https://colab.research.google.com/drive/165Zj7wdiFaawbmzGROYqf3__UqY_-DgP?usp=sharing. The issue does not seem to be related to PyTorch Lightning, since I can recreate it without Lightning. Make sure to switch the runtime to GPU in the settings. Basically, if you stop FeatureAblation part way through, you can see that the GPU's VRAM usage increases (doubles!) without being freed. In my Colab example, around 630 MB of VRAM is held before calling .attribute(...); after stopping .attribute(...) part way through, around 1400 MB is held. I believe this issue is also present in ShapleyValueSampling and perhaps some other methods as well.

It could be that the VRAM is freed only when the method finishes cleanly; I am testing that on Colab now. Regardless, I don't think this behavior is expected, and it seems possible to consume arbitrary VRAM by calling and then stopping one of these methods. This is likely the source of both problems I was experiencing.
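
Roughly, the check in the notebook looks like this (a sketch; `ablator` and `inp` stand for the FeatureAblation instance and input tensor, and the exact numbers will differ):

    import torch

    print(torch.cuda.memory_allocated() // 2**20, "MiB allocated before attribute")
    try:
        ablator.attribute(inp, target=1, perturbations_per_eval=1)
    except KeyboardInterrupt:
        pass  # stop the run part way through (e.g. interrupt the cell)
    torch.cuda.empty_cache()
    print(torch.cuda.memory_allocated() // 2**20, "MiB still held after interrupting")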

ndalton12 commented 3 years ago

Although I will note that in the code I was originally running, the VRAM usage would also explode during the .attribute(...) call itself, not just stay held afterwards, so there may be multiple memory-related problems here.

NarineK commented 2 years ago

@ndalton12 is this still an open problem?

ndalton12 commented 2 years ago

@NarineK I haven't had (and won't have) time to test this lately, but I assume the issue is still there if the code for those methods hasn't been touched. Otherwise, I'm really not sure. Feel free to close the issue if you feel comfortable doing so.