fastai / fastai

The fastai deep learning library
http://docs.fast.ai
Apache License 2.0
26.23k stars 7.56k forks source link

Segmentation error: no implementation found for 'torch.nn.functional.cross_entropy' on types that implement __torch_function__ #3022

Closed mulac closed 3 years ago

mulac commented 3 years ago

Please confirm you have the latest versions of fastai, fastcore, fastscript, and nbdev prior to reporting a bug (delete one): YES

Describe the bug: In the 23_tutorial.vision.ipynb notebook, training the segmentation model fails in the section titled "Segmentation - Using the high-level API".

To Reproduce Steps to reproduce the behavior:

  1. Open 23_tutorial.vision.ipynb in colab
  2. Run all

Expected behavior: The model should train successfully after calling learn.fine_tune().

Error TypeError: no implementation found for 'torch.nn.functional.cross_entropy' on types that implement __torch_function__: [<class 'fastai.torch_core.TensorImage'>, <class 'fastai.torch_core.TensorMask'>]

click to view full stack trace
``` --------------------------------------------------------------------------- TypeError Traceback (most recent call last) in () 1 learn = unet_learner(dls, resnet34) ----> 2 learn.fine_tune(8) 17 frames /usr/local/lib/python3.6/dist-packages/fastai/callback/schedule.py in fine_tune(self, epochs, base_lr, freeze_epochs, lr_mult, pct_start, div, **kwargs) 155 "Fine tune with `freeze` for `freeze_epochs` then with `unfreeze` from `epochs` using discriminative LR" 156 self.freeze() --> 157 self.fit_one_cycle(freeze_epochs, slice(base_lr), pct_start=0.99, **kwargs) 158 base_lr /= 2 159 self.unfreeze() /usr/local/lib/python3.6/dist-packages/fastai/callback/schedule.py in fit_one_cycle(self, n_epoch, lr_max, div, div_final, pct_start, wd, moms, cbs, reset_opt) 110 scheds = {'lr': combined_cos(pct_start, lr_max/div, lr_max, lr_max/div_final), 111 'mom': combined_cos(pct_start, *(self.moms if moms is None else moms))} --> 112 self.fit(n_epoch, cbs=ParamScheduler(scheds)+L(cbs), reset_opt=reset_opt, wd=wd) 113 114 # Cell /usr/local/lib/python3.6/dist-packages/fastai/learner.py in fit(self, n_epoch, lr, wd, cbs, reset_opt) 203 self.opt.set_hypers(lr=self.lr if lr is None else lr) 204 self.n_epoch = n_epoch --> 205 self._with_events(self._do_fit, 'fit', CancelFitException, self._end_cleanup) 206 207 def _end_cleanup(self): self.dl,self.xb,self.yb,self.pred,self.loss = None,(None,),(None,),None,None /usr/local/lib/python3.6/dist-packages/fastai/learner.py in _with_events(self, f, event_type, ex, final) 152 153 def _with_events(self, f, event_type, ex, final=noop): --> 154 try: self(f'before_{event_type}') ;f() 155 except ex: self(f'after_cancel_{event_type}') 156 finally: self(f'after_{event_type}') ;final() /usr/local/lib/python3.6/dist-packages/fastai/learner.py in _do_fit(self) 194 for epoch in range(self.n_epoch): 195 self.epoch=epoch --> 196 self._with_events(self._do_epoch, 'epoch', CancelEpochException) 197 198 def fit(self, n_epoch, lr=None, wd=None, cbs=None, 
reset_opt=False): /usr/local/lib/python3.6/dist-packages/fastai/learner.py in _with_events(self, f, event_type, ex, final) 152 153 def _with_events(self, f, event_type, ex, final=noop): --> 154 try: self(f'before_{event_type}') ;f() 155 except ex: self(f'after_cancel_{event_type}') 156 finally: self(f'after_{event_type}') ;final() /usr/local/lib/python3.6/dist-packages/fastai/learner.py in _do_epoch(self) 188 189 def _do_epoch(self): --> 190 self._do_epoch_train() 191 self._do_epoch_validate() 192 /usr/local/lib/python3.6/dist-packages/fastai/learner.py in _do_epoch_train(self) 180 def _do_epoch_train(self): 181 self.dl = self.dls.train --> 182 self._with_events(self.all_batches, 'train', CancelTrainException) 183 184 def _do_epoch_validate(self, ds_idx=1, dl=None): /usr/local/lib/python3.6/dist-packages/fastai/learner.py in _with_events(self, f, event_type, ex, final) 152 153 def _with_events(self, f, event_type, ex, final=noop): --> 154 try: self(f'before_{event_type}') ;f() 155 except ex: self(f'after_cancel_{event_type}') 156 finally: self(f'after_{event_type}') ;final() /usr/local/lib/python3.6/dist-packages/fastai/learner.py in all_batches(self) 158 def all_batches(self): 159 self.n_iter = len(self.dl) --> 160 for o in enumerate(self.dl): self.one_batch(*o) 161 162 def _do_one_batch(self): /usr/local/lib/python3.6/dist-packages/fastai/learner.py in one_batch(self, i, b) 176 self.iter = i 177 self._split(b) --> 178 self._with_events(self._do_one_batch, 'batch', CancelBatchException) 179 180 def _do_epoch_train(self): /usr/local/lib/python3.6/dist-packages/fastai/learner.py in _with_events(self, f, event_type, ex, final) 152 153 def _with_events(self, f, event_type, ex, final=noop): --> 154 try: self(f'before_{event_type}') ;f() 155 except ex: self(f'after_cancel_{event_type}') 156 finally: self(f'after_{event_type}') ;final() /usr/local/lib/python3.6/dist-packages/fastai/learner.py in _do_one_batch(self) 163 self.pred = self.model(*self.xb) 164 
self('after_pred') --> 165 if len(self.yb): self.loss = self.loss_func(self.pred, *self.yb) 166 self('after_loss') 167 if not self.training or not len(self.yb): return /usr/local/lib/python3.6/dist-packages/fastai/losses.py in __call__(self, inp, targ, **kwargs) 31 if targ.dtype in [torch.int8, torch.int16, torch.int32]: targ = targ.long() 32 if self.flatten: inp = inp.view(-1,inp.shape[-1]) if self.is_2d else inp.view(-1) ---> 33 return self.func.__call__(inp, targ.view(-1) if self.flatten else targ, **kwargs) 34 35 # Cell /usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs) 725 result = self._slow_forward(*input, **kwargs) 726 else: --> 727 result = self.forward(*input, **kwargs) 728 for hook in itertools.chain( 729 _global_forward_hooks.values(), /usr/local/lib/python3.6/dist-packages/torch/nn/modules/loss.py in forward(self, input, target) 960 def forward(self, input: Tensor, target: Tensor) -> Tensor: 961 return F.cross_entropy(input, target, weight=self.weight, --> 962 ignore_index=self.ignore_index, reduction=self.reduction) 963 964 /usr/local/lib/python3.6/dist-packages/torch/nn/functional.py in cross_entropy(input, target, weight, size_average, ignore_index, reduce, reduction) 2463 cross_entropy, tens_ops, input, target, weight=weight, 2464 size_average=size_average, ignore_index=ignore_index, reduce=reduce, -> 2465 reduction=reduction) 2466 if size_average is not None or reduce is not None: 2467 reduction = _Reduction.legacy_get_string(size_average, reduce) /usr/local/lib/python3.6/dist-packages/torch/overrides.py in handle_torch_function(public_api, relevant_args, *args, **kwargs) 1069 raise TypeError("no implementation found for '{}' on types that implement " 1070 '__torch_function__: {}' -> 1071 .format(func_name, list(map(type, overloaded_args)))) 1072 1073 def has_torch_function(relevant_args: Iterable[Any]) -> bool: TypeError: no implementation found for 'torch.nn.functional.cross_entropy' on types 
that implement __torch_function__: [<class 'fastai.torch_core.TensorImage'>, <class 'fastai.torch_core.TensorMask'>] ```
jph00 commented 3 years ago

Many thanks for this very clear bug report!