fastai / fastai_dev

fast.ai early development experiments
Apache License 2.0
639 stars 351 forks source link

learn.validate() blows up with Precision() and Recall() in metrics #230

Status: Closed (radekosmulski closed this issue 5 years ago)

radekosmulski commented 5 years ago

If I have a trained model and I add new metrics to it as follows:

learn.metrics += [Precision(), Recall()]

and then run

learn.validate()

I get the following error:

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-26-631604a2e07b> in <module>
----> 1 learn.validate()

~/work/fastai_dev/dev/local/learner.py in validate(self, dl, cbs)
    228             self(['begin_fit', 'begin_epoch', 'begin_validate'])
    229             self.all_batches()
--> 230             self(['after_validate', 'after_epoch', 'after_fit'])
    231         return self.recorder.values[-1]
    232 

~/work/fastai_dev/dev/local/learner.py in __call__(self, event_name)
    245     def __call__(self, event_name):
    246         "Call `event_name` (one or a list) for all callbacks"
--> 247         for e in L(event_name): self._call_one(e)
    248 
    249     def _call_one(self, event_name):

~/work/fastai_dev/dev/local/learner.py in _call_one(self, event_name)
    249     def _call_one(self, event_name):
    250         assert hasattr(event, event_name)
--> 251         [cb(event_name) for cb in sort_by_run(self.cbs)]
    252 
    253     @contextmanager

~/work/fastai_dev/dev/local/learner.py in <listcomp>(.0)
    249     def _call_one(self, event_name):
    250         assert hasattr(event, event_name)
--> 251         [cb(event_name) for cb in sort_by_run(self.cbs)]
    252 
    253     @contextmanager

~/work/fastai_dev/dev/local/learner.py in __call__(self, event_name)
     21 class Callback():
     22     "Basic class handling tweaks of the training loop by changing a `Learner` in various events"
---> 23     def __call__(self, event_name): getattr(self, event_name, noop)()
     24     def __repr__(self): return self.__class__.__name__
     25     def __getattr__(self, k):

~/work/fastai_dev/dev/local/learner.py in after_validate(self)
    404     def after_train   (self): self.log += [_maybe_item(m.value) for m in self._train_mets]
    405     def begin_validate(self): [m.reset() for m in self._valid_mets]
--> 406     def after_validate(self): self.log += [_maybe_item(m.value) for m in self._valid_mets]
    407 
    408     def after_cancel_train(self):    self.cancel_train = True

~/work/fastai_dev/dev/local/learner.py in <listcomp>(.0)
    404     def after_train   (self): self.log += [_maybe_item(m.value) for m in self._train_mets]
    405     def begin_validate(self): [m.reset() for m in self._valid_mets]
--> 406     def after_validate(self): self.log += [_maybe_item(m.value) for m in self._valid_mets]
    407 
    408     def after_cancel_train(self):    self.cancel_train = True

~/work/fastai_dev/dev/local/learner.py in _maybe_item(t)
    366 
    367 def _maybe_item(t):
--> 368     return t.item() if t.numel()==1 else t
    369 
    370 class Recorder(Callback):

AttributeError: 'numpy.float64' object has no attribute 'numel'

I can 'fix' this using the workaround shown in the attached screenshot (image not reproduced here), but this is obviously not the fix we want 🙂

sgugger commented 5 years ago

Fix is not that different though ;)