timeseriesAI / tsai

Time series Timeseries Deep Learning Machine Learning Python Pytorch fastai | State-of-the-art Deep Learning library for Time Series and Sequences in Pytorch / fastai
https://timeseriesai.github.io/tsai/
Apache License 2.0
4.91k stars 622 forks source link

AttributeError: Exception occured in `ShowGraph` when calling event `after_fit`: 'bool' object has no attribute 'all' #856

Open cversek opened 7 months ago

cversek commented 7 months ago

After applying a workaround for issue #847, I continued running the notebook 01a_MultiClass_MultiLabel_TSClassification.ipynb and hit the following error in the MultiLabel section (I believe it is unrelated to the previous issue):

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[31], line 2
      1 learn = ts_learner(dls, InceptionTimePlus, metrics=[partial(accuracy_multi, by_sample=True), partial(accuracy_multi, by_sample=False)], cbs=ShowGraph(), d=1)
----> 2 learn.fit_one_cycle(10, lr_max=1e-3)

File ~/mambaforge/envs/neurovep_data/lib/python3.11/site-packages/fastai/callback/schedule.py:119, in fit_one_cycle(self, n_epoch, lr_max, div, div_final, pct_start, wd, moms, cbs, reset_opt, start_epoch)
    116 lr_max = np.array([h['lr'] for h in self.opt.hypers])
    117 scheds = {'lr': combined_cos(pct_start, lr_max/div, lr_max, lr_max/div_final),
    118           'mom': combined_cos(pct_start, *(self.moms if moms is None else moms))}
--> 119 self.fit(n_epoch, cbs=ParamScheduler(scheds)+L(cbs), reset_opt=reset_opt, wd=wd, start_epoch=start_epoch)

File ~/mambaforge/envs/neurovep_data/lib/python3.11/site-packages/fastai/learner.py:264, in Learner.fit(self, n_epoch, lr, wd, cbs, reset_opt, start_epoch)
    262 self.opt.set_hypers(lr=self.lr if lr is None else lr)
    263 self.n_epoch = n_epoch
--> 264 self._with_events(self._do_fit, 'fit', CancelFitException, self._end_cleanup)

File ~/mambaforge/envs/neurovep_data/lib/python3.11/site-packages/fastai/learner.py:201, in Learner._with_events(self, f, event_type, ex, final)
    199 try: self(f'before_{event_type}');  f()
    200 except ex: self(f'after_cancel_{event_type}')
--> 201 self(f'after_{event_type}');  final()

File ~/mambaforge/envs/neurovep_data/lib/python3.11/site-packages/fastai/learner.py:172, in Learner.__call__(self, event_name)
--> 172 def __call__(self, event_name): L(event_name).map(self._call_one)

File ~/mambaforge/envs/neurovep_data/lib/python3.11/site-packages/fastcore/foundation.py:156, in L.map(self, f, *args, **kwargs)
--> 156 def map(self, f, *args, **kwargs): return self._new(map_ex(self, f, *args, gen=False, **kwargs))

File ~/mambaforge/envs/neurovep_data/lib/python3.11/site-packages/fastcore/basics.py:840, in map_ex(iterable, f, gen, *args, **kwargs)
    838 res = map(g, iterable)
    839 if gen: return res
--> 840 return list(res)

File ~/mambaforge/envs/neurovep_data/lib/python3.11/site-packages/fastcore/basics.py:825, in bind.__call__(self, *args, **kwargs)
    823     if isinstance(v,_Arg): kwargs[k] = args.pop(v.i)
    824 fargs = [args[x.i] if isinstance(x, _Arg) else x for x in self.pargs] + args[self.maxi+1:]
--> 825 return self.func(*fargs, **kwargs)

File ~/mambaforge/envs/neurovep_data/lib/python3.11/site-packages/fastai/learner.py:176, in Learner._call_one(self, event_name)
    174 def _call_one(self, event_name):
    175     if not hasattr(event, event_name): raise Exception(f'missing {event_name}')
--> 176     for cb in self.cbs.sorted('order'): cb(event_name)

File ~/mambaforge/envs/neurovep_data/lib/python3.11/site-packages/fastai/callback/core.py:62, in Callback.__call__(self, event_name)
     60     try: res = getcallable(self, event_name)()
     61     except (CancelBatchException, CancelBackwardException, CancelEpochException, CancelFitException, CancelStepException, CancelTrainException, CancelValidException): raise
---> 62     except Exception as e: raise modify_exception(e, f'Exception occured in `{self.__class__.__name__}` when calling event `{event_name}`:\n\t{e.args[0]}', replace=True)
     63 if event_name=='after_fit': self.run=True #Reset self.run to True at each end of fit
     64 return res

File ~/mambaforge/envs/neurovep_data/lib/python3.11/site-packages/fastai/callback/core.py:60, in Callback.__call__(self, event_name)
     58 res = None
     59 if self.run and _run: 
---> 60     try: res = getcallable(self, event_name)()
     61     except (CancelBatchException, CancelBackwardException, CancelEpochException, CancelFitException, CancelStepException, CancelTrainException, CancelValidException): raise
     62     except Exception as e: raise modify_exception(e, f'Exception occured in `{self.__class__.__name__}` when calling event `{event_name}`:\n\t{e.args[0]}', replace=True)

File ~/gitwork/cversek/tsai/tsai/callback/core.py:101, in ShowGraph.after_fit(self)
     99     plt.close(self.graph_ax.figure)
    100 if self.plot_metrics: 
--> 101     self.learn.plot_metrics(final_losses=self.final_losses, perc=self.perc)

File ~/gitwork/cversek/tsai/tsai/learner.py:233, in plot_metrics(self, **kwargs)
    230 @patch
    231 @delegates(subplots)
    232 def plot_metrics(self: Learner, **kwargs):
--> 233     self.recorder.plot_metrics(**kwargs)

File ~/gitwork/cversek/tsai/tsai/learner.py:220, in plot_metrics(self, nrows, ncols, figsize, final_losses, perc, **kwargs)
    218 else:
    219     color = '#ff7f0e'
--> 220     label = 'valid' if (m != [None] * len(m)).all() else None
    221     axs[ax_idx].grid(color='gainsboro', linewidth=.5)
    222 axs[ax_idx].plot(xs, m, color=color, label=label)

AttributeError: Exception occured in `ShowGraph` when calling event `after_fit`:
    'bool' object has no attribute 'all'

Here is the output of my_setup():

os              : Linux-6.2.0-36-generic-x86_64-with-glibc2.37
python          : 3.11.3
tsai            : 0.3.8
fastai          : 2.7.13
fastcore        : 1.5.29
torch           : 2.0.1
device          : 1 gpu (['NVIDIA GeForce RTX 3090'])
cpu cores       : 24
threads per cpu : 1
RAM             : 125.53 GB
GPU memory      : [24.0] GB

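The error appears to be triggered when the 'valid_accuracy_multi' metric is plotted. In `plot_metrics`, the expression `m != [None] * len(m)` returns a plain Python `bool` rather than an element-wise array. This indicates that `m` is a list here, not a NumPy array, so calling `.all()` on the result raises the AttributeError.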