fastai / course22p2

course.fast.ai 2022 part 2
https://course.fast.ai/Lessons/part2.html
Apache License 2.0

Put callback in traceback in case of exception. #8

Closed · PiotrCzapla closed this 1 year ago

PiotrCzapla commented 1 year ago

I was looking for a dynamic way of adding yield from ... at the end of a method. Any attempt at replacing f.__code__ or recompiling with exec makes the source code disappear in %debug, so there is no point in going that route.

Since we are not going to dynamically add yield from / await to every call, I think it is not worth changing before_*/after_* into a single method with code before and after a yield from, as that form would be used infrequently. Keeping the paired hooks simplifies the code even more.
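
For concreteness, here is a rough sketch of the two styles under discussion; the class and hook bodies are illustrative, not the notebook's actual code:

    class MetricsCB:
        # current style: two separate hooks, dispatched by name
        def before_fit(self, learn): print('resetting metrics')
        def after_fit(self, learn): print('logging metrics')

    class MetricsCB2:
        # alternative style: a single generator method, where the code
        # before the yield plays the role of before_fit and the code
        # after it plays the role of after_fit
        def fit(self, learn):
            print('resetting metrics')
            yield
            print('logging metrics')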

I'm making this a PR so the proposal doesn't get buried in the Discord channel.

review-notebook-app[bot] commented 1 year ago

Check out this pull request on ReviewNB.

See visual diffs & provide feedback on Jupyter Notebooks.

PiotrCzapla commented 1 year ago

I was composing example callbacks, and in the end I'm not happy with the results: the tracebacks are too long to be useful. Maybe it is better to proactively report the list of executed callbacks when an exception is raised (a sketch of that alternative is at the end of this comment). Here is a concrete example of an overly complicated traceback. First, the exception without the chain of callbacks in it:

RuntimeError                              Traceback (most recent call last)
Cell In[94], line 4
      2 cbs = [TrainCB(),  metrics, ProgressCB(plot=True)]
      3 learn = Learner(model, dls, F.cross_entropy, lr=0.2, cbs=cbs)
----> 4 learn.fit(1)

Cell In[93], line 44, in Learner.fit(self, n_epochs, train, valid, cbs, lr)
     42     with self.callback_ctx('fit'):
     43         for self.epoch in self.epochs:
---> 44             if train: self.one_epoch(True)
     45             if valid: torch.no_grad()(self.one_epoch)(False)
     46 finally:

Cell In[93], line 28, in Learner.one_epoch(self, train)
     26 for self.iter,self.batch in enumerate(self.dl):
     27     with self.callback_ctx('batch'):
---> 28         self.predict()
     29         self.get_loss()
     30         if self.training:

Cell In[93], line 53, in Learner.callback(self, method_nm, **kw)
---> 53 def callback(self, method_nm, **kw): return run_cbs(self.cbs, method_nm, self, **kw)

Cell In[66], line 17, in run_cbs(cbs, method_nm, learn, run)
---> 17     run(gen)
 # note: this frame was edited to match the original run_cbs

Cell In[34], line 3, in TrainCB.predict(self, learn)
----> 3 def predict(self, learn): learn.preds = learn.model(learn.batch[0])

File /opt/homebrew/Caskroom/miniforge/base/envs/miniai/lib/python3.10/site-packages/torch/nn/modules/module.py:1190, in Module._call_impl(self, *input, **kwargs)
   1186 # If we don't have any hooks, we want to skip the rest of the logic in
   1187 # this function, and just call forward.
   1188 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1189         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1190     return forward_call(*input, **kwargs)
   1191 # Do not call functions when jit is used
   1192 full_backward_hooks, non_full_backward_hooks = [], []

File /opt/homebrew/Caskroom/miniforge/base/envs/miniai/lib/python3.10/site-packages/torch/nn/modules/container.py:204, in Sequential.forward(self, input)
    202 def forward(self, input):
    203     for module in self:
--> 204         input = module(input)
    205     return input

File /opt/homebrew/Caskroom/miniforge/base/envs/miniai/lib/python3.10/site-packages/torch/nn/modules/module.py:1190, in Module._call_impl(self, *input, **kwargs)
   1186 # If we don't have any hooks, we want to skip the rest of the logic in
   1187 # this function, and just call forward.
   1188 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1189         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1190     return forward_call(*input, **kwargs)
   1191 # Do not call functions when jit is used
   1192 full_backward_hooks, non_full_backward_hooks = [], []

File /opt/homebrew/Caskroom/miniforge/base/envs/miniai/lib/python3.10/site-packages/torch/nn/modules/linear.py:114, in Linear.forward(self, input)
    113 def forward(self, input: Tensor) -> Tensor:
--> 114     return F.linear(input, self.weight, self.bias)

RuntimeError: Placeholder storage has not been allocated on MPS device!
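
For reference, the mechanism that puts the callbacks into the traceback can be reconstructed from the Cell In[66] and Cell In[95] frames shown below; this is a sketch pieced together from those frames, not the exact PR code. Each before_* callback runs inside a generator that then suspends, so its frame stays alive; when an exception escapes, __exit__ throws it back into the suspended chain, and the callback frames therefore appear in the traceback:

    def _cb_chain(cb, next_cb, learn):
        # run this callback (SimpleCB methods take no learner argument) ...
        if isinstance(cb.__self__, SimpleCB): cb()
        else: cb(learn)
        # ... then suspend, keeping this frame and the whole chain on the stack
        if next_cb is None: yield id(learn)
        else: yield from next_cb(learn)

    class Callback_ctx:
        # __init__/__enter__ omitted; only the frames visible below are shown
        def __exit__(self, type, value, tb):
            try:
                if type is None: self.learn.callback(f'after_{self.nm}')
                elif type is globals()[f'Cancel{self.nm.title()}Exception']: pass
                # re-raise inside the suspended generators so the before_*
                # callbacks show up in the traceback
                else: self.gen.throw(type, value, tb)
            finally: self.learn.callback(f'cleanup_{self.nm}')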

With that chain in place, the same exception becomes this traceback:

RuntimeError                              Traceback (most recent call last)
Cell In[96], line 4
      2 cbs = [TrainCB(),  metrics, ProgressCB(plot=True)]
      3 learn = Learner(model, dls, F.cross_entropy, lr=0.2, cbs=cbs)
----> 4 learn.fit(1)

Cell In[93], line 42, in Learner.fit(self, n_epochs, train, valid, cbs, lr)
     40 self.epochs = range(n_epochs)
     41 self.opt = self.opt_func(self.model.parameters(), self.lr if lr is None else lr)
---> 42 with self.callback_ctx('fit'):
     43     for self.epoch in self.epochs:
     44         if train: self.one_epoch(True)

Cell In[95], line 10, in Callback_ctx.__exit__(self, type, value, tb)
      8     if type is None: self.learn.callback(f'after_{self.nm}')
      9     elif  type is globals()[f'Cancel{self.nm.title()}Exception']: pass
---> 10     else: self.gen.throw(type, value, tb)
     11 finally: self.learn.callback(f'cleanup_{self.nm}')

File <string>:1, in MetricsCB__before_fit(f, *args, **kwargs)

Cell In[66], line 9, in _cb_chain(cb, next_cb, learn)
      7 else: cb(learn)
      8 if next_cb is None: yield id(learn)
----> 9 else:  yield from next_cb(learn)

File <string>:1, in ProgressCB__before_fit(f, *args, **kwargs)

Cell In[66], line 8, in _cb_chain(cb, next_cb, learn)
      6 if isinstance(cb.__self__, SimpleCB): cb()
      7 else: cb(learn)
----> 8 if next_cb is None: yield id(learn)
      9 else:  yield from next_cb(learn)

Cell In[93], line 44, in Learner.fit(self, n_epochs, train, valid, cbs, lr)
     42     with self.callback_ctx('fit'):
     43         for self.epoch in self.epochs:
---> 44             if train: self.one_epoch(True)
     45             if valid: torch.no_grad()(self.one_epoch)(False)
     46 finally:

Cell In[93], line 25, in Learner.one_epoch(self, train)
     23 self.model.train(train)
     24 self.dl = self.dls.train if train else self.dls.valid
---> 25 with self.callback_ctx('epoch'):
     26     for self.iter,self.batch in enumerate(self.dl):
     27         with self.callback_ctx('batch'):

Cell In[95], line 10, in Callback_ctx.__exit__(self, type, value, tb)
      8     if type is None: self.learn.callback(f'after_{self.nm}')
      9     elif  type is globals()[f'Cancel{self.nm.title()}Exception']: pass
---> 10     else: self.gen.throw(type, value, tb)
     11 finally: self.learn.callback(f'cleanup_{self.nm}')

File <string>:1, in MetricsCB__before_epoch(f, *args, **kwargs)

Cell In[66], line 9, in _cb_chain(cb, next_cb, learn)
      7 else: cb(learn)
      8 if next_cb is None: yield id(learn)
----> 9 else:  yield from next_cb(learn)

File <string>:1, in ProgressCB__before_epoch(f, *args, **kwargs)

Cell In[66], line 8, in _cb_chain(cb, next_cb, learn)
      6 if isinstance(cb.__self__, SimpleCB): cb()
      7 else: cb(learn)
----> 8 if next_cb is None: yield id(learn)
      9 else:  yield from next_cb(learn)

Cell In[93], line 27, in Learner.one_epoch(self, train)
     25 with self.callback_ctx('epoch'):
     26     for self.iter,self.batch in enumerate(self.dl):
---> 27         with self.callback_ctx('batch'):
     28             self.predict()
     29             self.get_loss()

Cell In[95], line 10, in Callback_ctx.__exit__(self, type, value, tb)
      8     if type is None: self.learn.callback(f'after_{self.nm}')
      9     elif  type is globals()[f'Cancel{self.nm.title()}Exception']: pass
---> 10     else: self.gen.throw(type, value, tb)
     11 finally: self.learn.callback(f'cleanup_{self.nm}')

Cell In[66], line 8, in _cb_chain(cb, next_cb, learn)
      6 if isinstance(cb.__self__, SimpleCB): cb()
      7 else: cb(learn)
----> 8 if next_cb is None: yield id(learn)
      9 else:  yield from next_cb(learn)

Cell In[93], line 28, in Learner.one_epoch(self, train)
     26 for self.iter,self.batch in enumerate(self.dl):
     27     with self.callback_ctx('batch'):
---> 28         self.predict()
     29         self.get_loss()
     30         if self.training:

Cell In[93], line 53, in Learner.callback(self, method_nm, **kw)
---> 53 def callback(self, method_nm, **kw): return run_cbs(self.cbs, method_nm, self, **kw)

Cell In[66], line 17, in run_cbs(cbs, method_nm, learn, run)
     15 gen = node(learn) if node else _cb_chain(id, None, learn)
     16 try:
---> 17     run(gen)
     18 except (CancelBatchException, CancelEpochException, CancelFitException ): raise
     19 # except Exception as e:
     20 #     gen.throw(e)

File <string>:1, in TrainCB__predict(f, *args, **kwargs)

Cell In[66], line 7, in _cb_chain(cb, next_cb, learn)
      5 def _cb_chain(cb, next_cb, learn):
      6     if isinstance(cb.__self__, SimpleCB): cb()
----> 7     else: cb(learn)
      8     if next_cb is None: yield id(learn)
      9     else:  yield from next_cb(learn)

Cell In[34], line 3, in TrainCB.predict(self, learn)
----> 3 def predict(self, learn): learn.preds = learn.model(learn.batch[0])

File /opt/homebrew/Caskroom/miniforge/base/envs/miniai/lib/python3.10/site-packages/torch/nn/modules/module.py:1190, in Module._call_impl(self, *input, **kwargs)
   1186 # If we don't have any hooks, we want to skip the rest of the logic in
   1187 # this function, and just call forward.
   1188 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1189         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1190     return forward_call(*input, **kwargs)
   1191 # Do not call functions when jit is used
   1192 full_backward_hooks, non_full_backward_hooks = [], []

File /opt/homebrew/Caskroom/miniforge/base/envs/miniai/lib/python3.10/site-packages/torch/nn/modules/container.py:204, in Sequential.forward(self, input)
    202 def forward(self, input):
    203     for module in self:
--> 204         input = module(input)
    205     return input

File /opt/homebrew/Caskroom/miniforge/base/envs/miniai/lib/python3.10/site-packages/torch/nn/modules/module.py:1190, in Module._call_impl(self, *input, **kwargs)
   1186 # If we don't have any hooks, we want to skip the rest of the logic in
   1187 # this function, and just call forward.
   1188 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1189         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1190     return forward_call(*input, **kwargs)
   1191 # Do not call functions when jit is used
   1192 full_backward_hooks, non_full_backward_hooks = [], []

File /opt/homebrew/Caskroom/miniforge/base/envs/miniai/lib/python3.10/site-packages/torch/nn/modules/linear.py:114, in Linear.forward(self, input)
    113 def forward(self, input: Tensor) -> Tensor:
--> 114     return F.linear(input, self.weight, self.bias)

RuntimeError: Placeholder storage has not been allocated on MPS device!
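
For comparison, a minimal sketch of the "report executed callbacks proactively" alternative mentioned above. This run_cbs is a simplified stand-in for the notebook's version (it assumes each callback has an order attribute, as in the course's Callback base class); the try/except reporting is the only addition:

    from operator import attrgetter

    def run_cbs(cbs, method_nm, learn):
        executed = []
        try:
            for cb in sorted(cbs, key=attrgetter('order')):
                method = getattr(cb, method_nm, None)
                if method is not None:
                    executed.append(f'{type(cb).__name__}.{method_nm}')
                    method(learn)
        except Exception:
            # instead of threading callbacks through the traceback, just
            # report which ones ran before the failure, then re-raise
            print('callbacks run before the exception:', ' -> '.join(executed))
            raise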