rwightman / efficientdet-pytorch

A PyTorch impl of EfficientDet faithful to the original Google impl w/ ported weights
Apache License 2.0
1.58k stars 293 forks

Can't use Pytorch backward hooks because of inplace ReLU6 #298

Open MaximeDeloche opened 10 months ago

MaximeDeloche commented 10 months ago

Hi,

I'm running into an issue when adding backward hooks to a model that includes a BiFPN (to generate some gradient plots). Calling forward on my BiFpn object raises the following error:

[...]
  File "/home/mdeloche/.venv/main/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/mdeloche/.venv/main/lib/python3.10/site-packages/effdet/efficientdet.py", line 460, in forward
    x = self.cell(x)
  File "/home/mdeloche/.venv/main/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/mdeloche/.venv/main/lib/python3.10/site-packages/effdet/efficientdet.py", line 38, in forward
    x = module(x)
  File "/home/mdeloche/.venv/main/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/mdeloche/.venv/main/lib/python3.10/site-packages/effdet/efficientdet.py", line 393, in forward
    x.append(fn(x))
  File "/home/mdeloche/.venv/main/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/mdeloche/.venv/main/lib/python3.10/site-packages/effdet/efficientdet.py", line 325, in forward
    return self.after_combine(self.combine(x))
  File "/home/mdeloche/.venv/main/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/mdeloche/.venv/main/lib/python3.10/site-packages/torch/nn/modules/container.py", line 217, in forward
    input = module(input)
  File "/home/mdeloche/.venv/main/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/mdeloche/.venv/main/lib/python3.10/site-packages/torch/nn/modules/activation.py", line 235, in forward
    return F.hardtanh(input, self.min_val, self.max_val, self.inplace)
  File "/home/mdeloche/.venv/main/lib/python3.10/site-packages/torch/nn/functional.py", line 1506, in hardtanh
    result = torch._C._nn.hardtanh_(input, min_val, max_val)
RuntimeError: Output 0 of BackwardHookFunctionBackward is a view and is being modified inplace. 
This view was created inside a custom Function (or because an input was returned as-is) and the autograd 
logic to handle view+inplace would override the custom backward associated with the custom Function, 
leading to incorrect gradients. This behavior is forbidden. You can fix this by cloning the output of the custom Function.

Here's a sample that reproduces the error:

import torch
from effdet import EfficientDet, get_efficientdet_config

class LayerHook:
    def __init__(self, layer) -> None:
        self.backward_hook = layer.register_full_backward_hook(self.hook_b)
        self.backward_output = None

    def hook_b(self, module, grad_input, grad_output):
        # No-op: registering the hook is enough to trigger the error
        pass

    def __enter__(self, *args):
        return self

    def __exit__(self, *args):
        self.backward_hook.remove()

if __name__ == "__main__":
    config = get_efficientdet_config("tf_efficientdet_lite0")
    config.update({"image_size": (128, 128)})

    model = EfficientDet(config)

    layer = model.fpn.cell[0].fnode[-1].combine

    with LayerHook(layer) as hook:
        output = model(torch.zeros((8, 3, 128, 128)))

I tracked it down to the following line, where the torch.nn.ReLU6 is created with inplace=True. Switching that to False fixes the error.

https://github.com/rwightman/efficientdet-pytorch/blob/d43c9e34cd62d22b4205831bb735f6dd83b8e881/effdet/efficientdet.py#L381

Are these operations created in place for performance reasons? Does anyone see a fix for this issue other than creating these activations with inplace=False?

Thanks in advance 👍

rwightman commented 4 months ago

Sorry for the very slow response. You could modify the model after creating it:

def clear_inplace(module):
    # Turn off in-place execution on any module that exposes the flag
    if hasattr(module, 'inplace'):
        module.inplace = False
    # Recurse so that nested activations are covered as well
    for child in module.children():
        clear_inplace(child)
    return module
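
For reference, a minimal sketch of how this could be combined with a full backward hook. The two-layer nn.Sequential below is a hypothetical stand-in for the BiFPN (it is not effdet code), and record_grad is an illustrative hook name. On recent PyTorch versions the first forward pass is expected to raise the RuntimeError from the report; after clear_inplace the hook should fire normally.

```python
import torch
from torch import nn


def clear_inplace(module):
    """Recursively set `inplace = False` on every submodule that has the flag."""
    if hasattr(module, "inplace"):
        module.inplace = False
    for child in module.children():
        clear_inplace(child)
    return module


# Hypothetical stand-in for the BiFPN: a hooked layer followed by an in-place ReLU6.
model = nn.Sequential(nn.Conv2d(3, 4, kernel_size=1), nn.ReLU6(inplace=True))

grad_shapes = []


def record_grad(module, grad_input, grad_output):
    # Store the shape of the gradient flowing into the hooked layer's output.
    grad_shapes.append(tuple(grad_output[0].shape))


handle = model[0].register_full_backward_hook(record_grad)

x = torch.zeros(2, 3, 8, 8)

# With the in-place activation still enabled, the forward pass should fail.
try:
    model(x)
    raised_before = False
except RuntimeError:
    raised_before = True

# Disable in-place ops on the already-built model, then run forward and backward.
clear_inplace(model)
model(x).sum().backward()
handle.remove()

print("raised before clear_inplace:", raised_before)
print("grad_output shapes seen by the hook:", grad_shapes)
```

For the model in the report, the equivalent would be calling clear_inplace(model) right after model = EfficientDet(config) and before registering the hook. The trade-off is some extra activation memory, since each ReLU6 then allocates a new output tensor.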