airctic / icevision

An Agnostic Computer Vision Framework - Pluggable to any Training Library: Fastai, Pytorch-Lightning with more to come
https://airctic.github.io/icevision/
Apache License 2.0
847 stars, 150 forks

Training with half precision (learn.to_fp16()) fails #479

Open: adamfarquhar opened this issue 3 years ago

adamfarquhar commented 3 years ago

🐛 Bug

Describe the bug

When I try to train a Faster RCNN model using IceVision and Fast.AI, it fails with the RuntimeError: Expected object of scalar type c10::Half but got scalar type float for argument 'other'

To Reproduce

Steps to reproduce the behavior:

  1. Take a faster_rcnn.fastai.learner that trains successfully at full precision
  2. Convert it to fp16 with learn.to_fp16()
  3. Train
  4. Training fails with the error above

Expected behavior

Training proceeds as normal.

Additional context

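from icevision.all import *
backbone = backbones.resnet_fpn.resnet34(pretrained=True)
model = faster_rcnn.model(num_classes=1+len(labels.component_labels), backbone=backbone)
metrics = [COCOMetric(metric_type=COCOMetricType.bbox)]
# train_dl and valid_dl are the reporter's existing dataloaders (not shown in the report)
learn = faster_rcnn.fastai.learner(dls=[train_dl, valid_dl], model=model, metrics=metrics)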

RuntimeError                              Traceback (most recent call last)
<ipython-input-13-3d2abe569935> in <module>
      1 learn.to_fp16()
----> 2 learn.fit_one_cycle(epochs,  cbs=[])

/opt/conda/lib/python3.7/site-packages/fastcore/logargs.py in _f(*args, **kwargs)
     54         init_args.update(log)
     55         setattr(inst, 'init_args', init_args)
---> 56         return inst if to_return else f(*args, **kwargs)
     57     return _f

/opt/conda/lib/python3.7/site-packages/fastai/callback/schedule.py in fit_one_cycle(self, n_epoch, lr_max, div, div_final, pct_start, wd, moms, cbs, reset_opt)
    111     scheds = {'lr': combined_cos(pct_start, lr_max/div, lr_max, lr_max/div_final),
    112               'mom': combined_cos(pct_start, *(self.moms if moms is None else moms))}
--> 113     self.fit(n_epoch, cbs=ParamScheduler(scheds)+L(cbs), reset_opt=reset_opt, wd=wd)
    114 
    115 # Cell

/opt/conda/lib/python3.7/site-packages/fastcore/logargs.py in _f(*args, **kwargs)
     54         init_args.update(log)
     55         setattr(inst, 'init_args', init_args)
---> 56         return inst if to_return else f(*args, **kwargs)
     57     return _f

/opt/conda/lib/python3.7/site-packages/fastai/learner.py in fit(self, n_epoch, lr, wd, cbs, reset_opt)
    205             self.opt.set_hypers(lr=self.lr if lr is None else lr)
    206             self.n_epoch = n_epoch
--> 207             self._with_events(self._do_fit, 'fit', CancelFitException, self._end_cleanup)
    208 
    209     def _end_cleanup(self): self.dl,self.xb,self.yb,self.pred,self.loss = None,(None,),(None,),None,None

/opt/conda/lib/python3.7/site-packages/fastai/learner.py in _with_events(self, f, event_type, ex, final)
    153 
    154     def _with_events(self, f, event_type, ex, final=noop):
--> 155         try:       self(f'before_{event_type}')       ;f()
    156         except ex: self(f'after_cancel_{event_type}')
    157         finally:   self(f'after_{event_type}')        ;final()

/opt/conda/lib/python3.7/site-packages/fastai/learner.py in _do_fit(self)
    195         for epoch in range(self.n_epoch):
    196             self.epoch=epoch
--> 197             self._with_events(self._do_epoch, 'epoch', CancelEpochException)
    198 
    199     @log_args(but='cbs')

/opt/conda/lib/python3.7/site-packages/fastai/learner.py in _with_events(self, f, event_type, ex, final)
    153 
    154     def _with_events(self, f, event_type, ex, final=noop):
--> 155         try:       self(f'before_{event_type}')       ;f()
    156         except ex: self(f'after_cancel_{event_type}')
    157         finally:   self(f'after_{event_type}')        ;final()

/opt/conda/lib/python3.7/site-packages/fastai/learner.py in _do_epoch(self)
    189 
    190     def _do_epoch(self):
--> 191         self._do_epoch_train()
    192         self._do_epoch_validate()
    193 

/opt/conda/lib/python3.7/site-packages/fastai/learner.py in _do_epoch_train(self)
    181     def _do_epoch_train(self):
    182         self.dl = self.dls.train
--> 183         self._with_events(self.all_batches, 'train', CancelTrainException)
    184 
    185     def _do_epoch_validate(self, ds_idx=1, dl=None):

/opt/conda/lib/python3.7/site-packages/fastai/learner.py in _with_events(self, f, event_type, ex, final)
    153 
    154     def _with_events(self, f, event_type, ex, final=noop):
--> 155         try:       self(f'before_{event_type}')       ;f()
    156         except ex: self(f'after_cancel_{event_type}')
    157         finally:   self(f'after_{event_type}')        ;final()

/opt/conda/lib/python3.7/site-packages/fastai/learner.py in all_batches(self)
    159     def all_batches(self):
    160         self.n_iter = len(self.dl)
--> 161         for o in enumerate(self.dl): self.one_batch(*o)
    162 
    163     def _do_one_batch(self):

/opt/conda/lib/python3.7/site-packages/fastai/learner.py in one_batch(self, i, b)
    177         self.iter = i
    178         self._split(b)
--> 179         self._with_events(self._do_one_batch, 'batch', CancelBatchException)
    180 
    181     def _do_epoch_train(self):

/opt/conda/lib/python3.7/site-packages/fastai/learner.py in _with_events(self, f, event_type, ex, final)
    153 
    154     def _with_events(self, f, event_type, ex, final=noop):
--> 155         try:       self(f'before_{event_type}')       ;f()
    156         except ex: self(f'after_cancel_{event_type}')
    157         finally:   self(f'after_{event_type}')        ;final()

/opt/conda/lib/python3.7/site-packages/fastai/learner.py in _do_one_batch(self)
    162 
    163     def _do_one_batch(self):
--> 164         self.pred = self.model(*self.xb)
    165         self('after_pred')
    166         if len(self.yb): self.loss = self.loss_func(self.pred, *self.yb)

/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    720             result = self._slow_forward(*input, **kwargs)
    721         else:
--> 722             result = self.forward(*input, **kwargs)
    723         for hook in itertools.chain(
    724                 _global_forward_hooks.values(),

/opt/conda/lib/python3.7/site-packages/torchvision/models/detection/generalized_rcnn.py in forward(self, images, targets)
     96         if isinstance(features, torch.Tensor):
     97             features = OrderedDict([('0', features)])
---> 98         proposals, proposal_losses = self.rpn(images, features, targets)
     99         detections, detector_losses = self.roi_heads(features, proposals, images.image_sizes, targets)
    100         detections = self.transform.postprocess(detections, images.image_sizes, original_image_sizes)

/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    720             result = self._slow_forward(*input, **kwargs)
    721         else:
--> 722             result = self.forward(*input, **kwargs)
    723         for hook in itertools.chain(
    724                 _global_forward_hooks.values(),

/opt/conda/lib/python3.7/site-packages/torchvision/models/detection/rpn.py in forward(self, images, features, targets)
    496         if self.training:
    497             assert targets is not None
--> 498             labels, matched_gt_boxes = self.assign_targets_to_anchors(anchors, targets)
    499             regression_targets = self.box_coder.encode(matched_gt_boxes, anchors)
    500             loss_objectness, loss_rpn_box_reg = self.compute_loss(

/opt/conda/lib/python3.7/site-packages/torchvision/models/detection/rpn.py in assign_targets_to_anchors(self, anchors, targets)
    338                 labels_per_image = torch.zeros((anchors_per_image.shape[0],), dtype=torch.float32, device=device)
    339             else:
--> 340                 match_quality_matrix = box_ops.box_iou(gt_boxes, anchors_per_image)
    341                 matched_idxs = self.proposal_matcher(match_quality_matrix)
    342                 # get the targets corresponding GT for each proposal

/opt/conda/lib/python3.7/site-packages/torchvision/ops/boxes.py in box_iou(boxes1, boxes2)
    167     area2 = box_area(boxes2)
    168 
--> 169     lt = torch.max(boxes1[:, None, :2], boxes2[:, :2])  # [N,M,2]
    170     rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:])  # [N,M,2]
    171 

RuntimeError: Expected object of scalar type c10::Half but got scalar type float for argument 'other'

oke-aditya commented 3 years ago

My thought: learn.to_fp16() is making the internal functions receive tensors of different dtypes.

I recommend trying fp16 with the PyTorch Lightning trainer instead, since I am quite sure that torchvision supports mixed precision training for Faster R-CNN.

Although the stack trace ends in torchvision's rpn.py, I am sure the problem is not in torchvision itself. I think fastai mishandles these tensors.
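If that diagnosis is right, one possible workaround is to make sure the ground-truth boxes reach the model in float32. The error message points that way: in box_iou(gt_boxes, anchors_per_image) the first argument (the ground-truth boxes) is Half, while the 'other' argument (the anchors) is float. Below is a minimal sketch, assuming torchvision-style targets (a list of dicts with a "boxes" key). cast_target_boxes is a hypothetical helper, not part of icevision, fastai or torchvision, and where to call it (for example, from a fastai callback) is left as an assumption.

```python
from typing import Dict, List

import torch


def cast_target_boxes(
    targets: List[Dict[str, torch.Tensor]], dtype: torch.dtype = torch.float32
) -> List[Dict[str, torch.Tensor]]:
    """Hypothetical helper: return new target dicts whose "boxes" are cast to `dtype`.

    Other entries (labels, iscrowd, ...) are passed through unchanged, so integer
    labels are never converted to a floating-point type.
    """
    casted = []
    for target in targets:
        target = dict(target)  # shallow copy: the caller's dicts are left untouched
        if "boxes" in target:
            target["boxes"] = target["boxes"].to(dtype)
        casted.append(target)
    return casted


# Example: targets whose boxes have been converted to half precision
half_targets = [{"boxes": torch.rand(5, 4).half(), "labels": torch.zeros(5, dtype=torch.int64)}]
fixed = cast_target_boxes(half_targets)
print(fixed[0]["boxes"].dtype, fixed[0]["labels"].dtype)  # torch.float32 torch.int64
```

Casting only the "boxes" entry (rather than every tensor in the target) keeps the integer labels intact, which the detection losses require.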

Here is a minimal example:

import torch, torchvision

device = torch.device('cuda')

model = torchvision.models.detection.fasterrcnn_resnet50_fpn()
model.to(device)  # a freshly built model is in training mode, so it returns a dict of losses

images = [torch.rand(3, 300, 400, device=device)]
# random boxes in (x1, y1, x2, y2) format; adding (x1, y1) to (x2, y2) guarantees x2 >= x1 and y2 >= y1
boxes = torch.rand((5, 4), dtype=torch.float32, device=device)
boxes[:, 2:] += boxes[:, :2]
target = [{"boxes": boxes,
          "labels": torch.zeros(5, dtype=torch.int64, device=device),
          "image_id": 4,
          "area": torch.zeros(5, dtype=torch.float32, device=device),
          "iscrowd": torch.zeros((5,), dtype=torch.int64, device=device)}]

# use automatic mixed precision for the forward pass
with torch.cuda.amp.autocast():
    loss_dict = model(images, target)
losses = sum(loss for loss in loss_dict.values())
# perform backward outside of autocast context manager
losses.backward()

Here, too, Faster R-CNN goes through the same rpn.py code that contains these functions. The dtype also doesn't matter for box_iou or the other box operations: they accept any floating-point tensor, as long as both box tensors have the same dtype.

Can you try training once with Lightning? If the error still persists, let me know. The minimal PyTorch example should definitely work, though.
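The minimal PyTorch example performs a single forward and backward pass. In an actual training loop, torch.cuda.amp.autocast is usually paired with torch.cuda.amp.GradScaler, which scales the loss so that small fp16 gradients don't underflow. Below is a sketch of one such step, assuming a torchvision-style detection model that returns a dict of losses in training mode. detection_amp_step and ToyDetector are hypothetical names. ToyDetector is only a lightweight stand-in so the snippet runs without a GPU or pretrained weights; on a CPU-only machine, autocast and the scaler are simply disabled.

```python
import torch
from torch import nn


def detection_amp_step(model, optimizer, scaler, images, targets, use_amp):
    """One mixed-precision training step for a model that maps (images, targets) to a dict of losses."""
    model.train()
    optimizer.zero_grad()
    # Forward pass and loss reduction run under autocast only.
    with torch.cuda.amp.autocast(enabled=use_amp):
        loss_dict = model(images, targets)
        loss = sum(loss_dict.values())
    # Backward outside autocast; the scaler multiplies the loss to avoid fp16 gradient underflow.
    scaler.scale(loss).backward()
    scaler.step(optimizer)  # unscales gradients and skips the update if they contain inf/NaN
    scaler.update()         # adjusts the scale factor for the next iteration
    return loss.detach()


class ToyDetector(nn.Module):
    """Hypothetical stand-in with the same training-mode contract as torchvision detectors."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 1)

    def forward(self, images, targets):
        # images are ignored; a dummy loss is computed from the target boxes
        boxes = torch.cat([t["boxes"] for t in targets])
        return {"loss_box": self.linear(boxes).pow(2).mean()}


use_amp = torch.cuda.is_available()
device = torch.device("cuda" if use_amp else "cpu")
model = ToyDetector().to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
images = [torch.rand(3, 32, 32, device=device)]
targets = [{"boxes": torch.rand(5, 4, device=device)}]
loss = detection_amp_step(model, optimizer, scaler, images, targets, use_amp)
print(loss.item())
```

With the Faster R-CNN model, inputs and target list from the minimal example, the same function would be called with use_amp=True, a torch.cuda.amp.GradScaler() and an optimizer over model.parameters() (for example torch.optim.SGD).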

Anjum48 commented 2 years ago

I can confirm that the same issue exists using PyTorch Lightning.

Steps to reproduce:

  1. Open getting_started_object_detection.ipynb
  2. Add precision=16 to the pl.Trainer args

FraPochetti commented 2 years ago

@ai-fast-track fixing this would very likely take a lot of effort. Shall I close it?

FraPochetti commented 2 years ago

@ai-fast-track wdyt?