Closed: lgvaz closed this issue 3 years ago
Is it ok if I look into this one as well (before diving together into KeyPoint detection)? I see it is already assigned to you @lgvaz
Perfect! I assigned it to both of us, I think this might be a tricky one to resolve (hopefully it's easy hahahah).
I have some ideas about what is happening, and we can talk about it tomorrow as well. As I said in the issue, it's about how the two callbacks interact with each other.
I will make sure to take a look at it before tomorrow's chat ;)
This error might instead be caused by replacing the learn.loss attribute in models/rcnn/fastai/learner.py:
# HACK: patch AvgLoss (in the original, find_bs gives errors)
class RCNNAvgLoss(fastai.AvgLoss):
    def accumulate(self, learn):
        bs = len(learn.yb)
        self.total += fastai.to_detach(learn.loss.mean()) * bs
        self.count += bs

recorder = [cb for cb in learn.cbs if isinstance(cb, fastai.Recorder)][0]
recorder.loss = RCNNAvgLoss()
While training an RCNN learner with fastai and the WandbCallback, the following exception is raised:
TypeError Traceback (most recent call last)
~/miniconda3/envs/wandb/lib/python3.8/site-packages/fastai/learner.py in _with_events(self, f, event_type, ex, final)
--> 156 try: self(f'before_{event_type}') ;f()
157 except ex: self(f'after_cancel_{event_type}')
~/miniconda3/envs/wandb/lib/python3.8/site-packages/fastai/learner.py in _do_one_batch(self)
--> 167 self.pred = self.model(*self.xb)
168
~/miniconda3/envs/wandb/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
721 else:
--> 722 result = self.forward(*input, **kwargs)
723 for hook in itertools.chain(
~/miniconda3/envs/wandb/lib/python3.8/site-packages/torchvision/models/detection/generalized_rcnn.py in forward(self, images, targets)
78
---> 79 images, targets = self.transform(images, targets)
80
~/miniconda3/envs/wandb/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
734 else:
--> 735 var = var[0]
736 grad_fn = var.grad_fn
TypeError: 'ImageList' object is not subscriptable
The issue is triggered by the fastai WandbCallback, which invokes wandb.watch(self.learn.model, log=self.log). This method from the wandb library creates hooks into the trained model to log relevant metrics during the forward and backward passes (losses, gradients, etc.).
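For intuition, here is a minimal sketch of what a watch-style utility does under the hood, using only the public torch.nn.Module hook API (this is an illustration, not wandb's actual implementation). The backward hooks are the relevant part for this bug:

import torch.nn as nn

def watch_sketch(model: nn.Module):
    """Illustrative stand-in for wandb.watch: hook every submodule."""
    def forward_hook(module, inputs, output):
        pass  # a real logger would record activations/parameters here

    def backward_hook(module, grad_input, grad_output):
        pass  # a real logger would record gradients here

    for module in model.modules():
        module.register_forward_hook(forward_hook)
        # Registering a backward hook populates module._backward_hooks,
        # which triggers the tensor-unwrapping snippet shown further down.
        module.register_backward_hook(backward_hook)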
This behavior interferes with the first building block of the torchvision implementation of an RCNN model, the GeneralizedRCNNTransform:

(transform): GeneralizedRCNNTransform(
    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    Resize(min_size=(800,), max_size=1333, mode='bilinear')
)

This transform preprocesses the batch before passing it to the model's backbone, returning an ImageList object.
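This is easy to verify in isolation; a small sketch (assuming only torchvision and torch are installed) shows that the transform returns an ImageList rather than a Tensor:

import torch
from torchvision.models.detection.transform import GeneralizedRCNNTransform

transform = GeneralizedRCNNTransform(
    min_size=800, max_size=1333,
    image_mean=[0.485, 0.456, 0.406], image_std=[0.229, 0.224, 0.225],
)
images = [torch.rand(3, 600, 800)]        # a batch of one image
image_list, targets = transform(images)   # targets is None here
print(type(image_list))                   # torchvision.models.detection.image_list.ImageList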
object. While running the model's forward pass, the presence of wandb
hooks triggers the execution of the following snippet from torch/nn/modules/module.py
if (len(self._backward_hooks) > 0) or (len(_global_backward_hooks) > 0):
var = result
while not isinstance(var, torch.Tensor):
if isinstance(var, dict):
var = next((v for v in var.values() if isinstance(v, torch.Tensor)))
else:
var = var[0]
which throws the exception: the ImageList returned by the transform is neither a Tensor nor a dict, and it does not support indexing, so var = var[0] raises the TypeError.
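A minimal repro of the failing line, again assuming only torchvision: ImageList stores its batch in a .tensors attribute but defines no __getitem__, so indexing it raises exactly this error.

import torch
from torchvision.models.detection.image_list import ImageList

il = ImageList(torch.rand(2, 3, 800, 800), [(800, 800), (800, 800)])
print(il.tensors.shape)  # torch.Size([2, 3, 800, 800]) -- the batched images
il[0]                    # TypeError: 'ImageList' object is not subscriptable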
The temporary solution we implemented consists of monkey-patching the wandb.watch function, replacing it with a no-op, ONLY if an RCNN learner is being trained. This does not disrupt the standard W&B logging and prevents the training process from failing.
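A minimal sketch of that workaround; the detection of the learner type (the is_rcnn argument below) is a placeholder for however the training code decides it:

import wandb

_original_watch = wandb.watch

def _noop_watch(*args, **kwargs):
    # Do not register any hooks on the model. Scalar logging via
    # wandb.log is unaffected, so losses and metrics are still recorded.
    return []

def patch_wandb_watch(is_rcnn: bool):
    """Replace wandb.watch with a no-op only when training an RCNN learner."""
    wandb.watch = _noop_watch if is_rcnn else _original_watch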
🐛 Bug
Describe the bug: An error is raised (stack trace below) when training an RCNN model with fastai and wandb.
To Reproduce: Train an RCNN model with fastai and the WandbCallback, as sketched below.
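A hedged sketch of the repro; learn stands in for an RCNN fastai learner built as in models/rcnn/fastai/learner.py, and the project name is hypothetical:

import wandb
from fastai.callback.wandb import WandbCallback

wandb.init(project="rcnn-repro")      # hypothetical project name
learn = ...                           # an RCNN fastai learner (see learner.py above)
learn.fit(1, cbs=[WandbCallback()])   # raises TypeError: 'ImageList' object is not subscriptable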
Expected behavior: Training completes without the exception being raised.
Why is this happening: I believe the RCNNCallback and the WandbCallback are in conflict somehow. This error does not happen if the WandbCallback is not used, or if efficientdet is used instead of rcnn.
Stack Trace