airctic / icevision

An Agnostic Computer Vision Framework - Pluggable to any Training Library: Fastai, Pytorch-Lightning with more to come
https://airctic.github.io/icevision/
Apache License 2.0

Wandb + fastai + RCNN error #490

Closed: lgvaz closed this issue 3 years ago

lgvaz commented 3 years ago

🐛 Bug

Describe the bug: An error is raised (stack trace below) when training an RCNN model with fastai and wandb.

To Reproduce: Train an RCNN model with fastai and the WandbCallback.


Why is this happening: I believe the RCNNCallback and the WandbCallback conflict with each other somehow. The error does not happen if the WandbCallback is not used, or if EfficientDet is used instead of RCNN.

Stack Trace

/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    734                 else:
--> 735                     var = var[0]
    736             grad_fn = var.grad_fn

TypeError: 'ImageList' object does not support indexing
...

During handling of the above exception, another exception occurred:

/usr/local/lib/python3.6/dist-packages/fastcore/foundation.py in __getattr__(self, k)
    157         if self._component_attr_filter(k):
    158             attr = getattr(self,self._default,None)
--> 159             if attr is not None: return getattr(attr,k)
    160         raise AttributeError(k)
    161     def __dir__(self): return custom_dir(self,self._dir())

AttributeError: 'Learner' object has no attribute 'smooth_loss'
FraPochetti commented 3 years ago

Is it ok if I look into this one as well (before diving together into KeyPoint detection)? I see it is already assigned to you @lgvaz

lgvaz commented 3 years ago

Perfect! I assigned it to both of us. I think this might be a tricky one to resolve (hopefully it's easy hahahah).

I have some ideas about what is happening, and we can talk about it tomorrow as well. As I said in the issue, I think it comes down to how the two callbacks interact with each other.

FraPochetti commented 3 years ago

I will make sure to take a look at it before tomorrow's chat ;)

lgvaz commented 3 years ago

This error might instead be caused by replacing the loss attribute of the Recorder callback in models/rcnn/fastai/learner.py:

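    import fastai.vision.all as fastai  # assumed alias; provides AvgLoss, to_detach and Recorder

    # HACK: patch AvgLoss (in the original, find_bs gives errors)
    class RCNNAvgLoss(fastai.AvgLoss):
        def accumulate(self, learn):
            # Weight the mean batch loss by a batch size taken from learn.yb
            bs = len(learn.yb)
            self.total += fastai.to_detach(learn.loss.mean()) * bs
            self.count += bs

    # `learn` is the Learner created earlier in the same function (not shown);
    # swap the Recorder's AvgLoss (its `loss` attribute) for the patched version
    recorder = [cb for cb in learn.cbs if isinstance(cb, fastai.Recorder)][0]
    recorder.loss = RCNNAvgLoss()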
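For context on what the patched accumulate computes: each batch adds its mean loss multiplied by the batch size to total, adds the batch size to count, and the reported value is total / count, a batch-size-weighted running mean. The sketch below reproduces that arithmetic in plain Python so it runs without fastai; the class RunningAvgLossSketch and its method signatures are hypothetical stand-ins, not fastai's API.

```python
from typing import Optional, Sequence


class RunningAvgLossSketch:
    """Hypothetical stand-in for a batch-size-weighted average loss (no fastai needed)."""

    def __init__(self) -> None:
        self.reset()

    def reset(self) -> None:
        # Cleared before each pass over the data.
        self.total = 0.0
        self.count = 0

    def accumulate(self, batch_losses: Sequence[float], bs: int) -> None:
        # Mean of this batch's loss values, weighted by the batch size,
        # mirroring `learn.loss.mean() * bs` in the RCNNAvgLoss patch.
        batch_mean = sum(batch_losses) / len(batch_losses)
        self.total += batch_mean * bs
        self.count += bs

    @property
    def value(self) -> Optional[float]:
        # Weighted mean over all samples seen so far; None before any batch.
        return self.total / self.count if self.count else None


avg = RunningAvgLossSketch()
avg.accumulate([2.0, 4.0], bs=2)  # batch mean 3.0, 2 samples
avg.accumulate([1.0], bs=1)       # batch mean 1.0, 1 sample
print(avg.value)                  # (3.0 * 2 + 1.0 * 1) / 3 = 7 / 3
```

Weighting by the batch size matters because the last batch of an epoch is often smaller; a plain mean of per-batch means would give it the same weight as a full batch. It also shows why the value of bs has to be right: if it is computed incorrectly, the average is weighted incorrectly too.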
FraPochetti commented 3 years ago

While training an RCNN learner with fastai and the WandbCallback, the following exception is raised:

TypeError                                 Traceback (most recent call last)
~/miniconda3/envs/wandb/lib/python3.8/site-packages/fastai/learner.py in _with_events(self, f, event_type, ex, final)
--> 156         try:       self(f'before_{event_type}')       ;f()
    157         except ex: self(f'after_cancel_{event_type}')

~/miniconda3/envs/wandb/lib/python3.8/site-packages/fastai/learner.py in _do_one_batch(self)
--> 167         self.pred = self.model(*self.xb)
    168 

~/miniconda3/envs/wandb/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    721         else:
--> 722             result = self.forward(*input, **kwargs)
    723         for hook in itertools.chain(

~/miniconda3/envs/wandb/lib/python3.8/site-packages/torchvision/models/detection/generalized_rcnn.py in forward(self, images, targets)
     78 
---> 79         images, targets = self.transform(images, targets)
     80 

~/miniconda3/envs/wandb/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    734                 else:
--> 735                     var = var[0]
    736             grad_fn = var.grad_fn

TypeError: 'ImageList' object is not subscriptable
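This final TypeError can be reproduced in plain Python, without torch or wandb. When backward hooks are registered, PyTorch unwraps a module's output to find a tensor (see the torch/nn/modules/module.py excerpt quoted in this comment): dicts are searched by value, and everything else is indexed with [0]. The transform returns (ImageList, targets), so the first [0] yields the ImageList, and the next [0] fails because it defines no __getitem__. In the sketch, ImageListLike, first_tensor_like and the use of plain floats as stand-in "tensors" are hypothetical simplifications.

```python
class ImageListLike:
    """Hypothetical stand-in for torchvision's ImageList: holds batched
    images and their sizes, but defines no __getitem__."""

    def __init__(self, tensors, image_sizes):
        self.tensors = tensors
        self.image_sizes = image_sizes


def first_tensor_like(result, is_tensor):
    """Mimic the unwrapping loop PyTorch runs when backward hooks exist."""
    var = result
    while not is_tensor(var):
        if isinstance(var, dict):
            var = next(v for v in var.values() if is_tensor(v))
        else:
            var = var[0]  # fails for objects without __getitem__
    return var


def is_fake_tensor(v):
    # Floats play the role of tensors in this sketch.
    return isinstance(v, float)


print(first_tensor_like(([1.5, 2.5],), is_fake_tensor))  # nested sequences unwrap fine
print(first_tensor_like({"loss": 0.7}, is_fake_tensor))  # dict outputs are searched by value

# (ImageList, targets): var[0] gives the ImageList, then var[0] raises TypeError.
caught = None
try:
    first_tensor_like((ImageListLike([1.0], [(800, 1067)]), None), is_fake_tensor)
except TypeError as err:
    caught = err
print(caught)
```

Depending on the Python version, the message reads "is not subscriptable" (as in this trace) or "does not support indexing" (as in the first trace, from Python 3.6). This likely explains why the two traces in this thread word the same error differently.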

The issue is triggered by the fastai WandbCallback, which calls wandb.watch(self.learn.model, log=self.log). This wandb function registers hooks on the model so that gradients and/or parameters (depending on the log argument) can be logged during the forward and backward passes. These hooks interfere with the first building block of torchvision's RCNN implementation, the GeneralizedRCNNTransform

(transform): GeneralizedRCNNTransform(
      Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
      Resize(min_size=(800,), max_size=1333, mode='bilinear')
  )

which preprocesses (normalizes and resizes) the batch before it reaches the model's backbone, and returns an ImageList object together with the targets. During the model's forward pass, the presence of the wandb hooks triggers the execution of the following snippet from torch/nn/modules/module.py

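        # Runs after forward() when backward hooks exist: look for a tensor
        # in the module's output to attach those hooks to
        if (len(self._backward_hooks) > 0) or (len(_global_backward_hooks) > 0):
            var = result
            while not isinstance(var, torch.Tensor):
                if isinstance(var, dict):
                    var = next((v for v in var.values() if isinstance(v, torch.Tensor)))
                else:
                    var = var[0]  # raises TypeError once var is an ImageList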

which throws the exception. As a temporary solution, we monkey-patch the wandb.watch function, replacing it with a no-op, ONLY when an RCNN learner is being trained. The rest of the W&B logging done by the WandbCallback (losses and metrics) keeps working. Only the gradient/parameter hooks are skipped, and training no longer fails.
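A minimal sketch of that kind of conditional no-op patch follows. It is not icevision's actual code: noop_attr and the is_rcnn flag are hypothetical, and a stand-in module object replaces wandb so the example runs without wandb installed. As a design choice, it uses a context manager that restores the original function afterwards, so non-RCNN learners trained later in the same session still get wandb.watch.

```python
import contextlib
import types


@contextlib.contextmanager
def noop_attr(module, name, enabled=True):
    """Hypothetical helper: temporarily replace module.<name> with a no-op.

    With enabled=False the attribute is left untouched, so the patch can be
    applied only for RCNN learners."""
    original = getattr(module, name)
    if enabled:
        setattr(module, name, lambda *args, **kwargs: None)
    try:
        yield
    finally:
        # Always restore the real function, even if training raises.
        setattr(module, name, original)


# Stand-in for the wandb module; records calls instead of adding hooks.
fake_wandb = types.ModuleType("fake_wandb")
calls = []
fake_wandb.watch = lambda model, log="gradients": calls.append((model, log))

is_rcnn = True  # hypothetical: would be derived from the model/learner type
with noop_attr(fake_wandb, "watch", enabled=is_rcnn):
    fake_wandb.watch("rcnn-model", log="gradients")  # no-op: no hooks registered
fake_wandb.watch("effdet-model", log="gradients")    # original restored
print(calls)
```

With the real library, the patch would target the wandb module and wrap the learn.fit(...) call. This works on the assumption that the WandbCallback looks up wandb.watch on the module at the moment fitting starts.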