Closed PiotrCzapla closed 1 year ago
Check out this pull request on
See visual diffs & provide feedback on Jupyter Notebooks.
Powered by ReviewNB
I was composing example callbacks, and in the end I'm not happy with the results. The tracebacks are too long to be useful. Maybe it would be better to proactively provide the list of executed callbacks when an exception is raised. Here is a concrete example of a traceback that becomes too complicated. First, the exception without the chain of callbacks:
RuntimeError Traceback (most recent call last)
Cell In[94], line 4
2 cbs = [TrainCB(), metrics, ProgressCB(plot=True)]
3 learn = Learner(model, dls, F.cross_entropy, lr=0.2, cbs=cbs)
----> 4 learn.fit(1)
Cell In[93], line 44, in Learner.fit(self, n_epochs, train, valid, cbs, lr)
42 with self.callback_ctx('fit'):
43 for self.epoch in self.epochs:
---> 44 if train: self.one_epoch(True)
45 if valid: torch.no_grad()(self.one_epoch)(False)
46 finally:
Cell In[93], line 28, in Learner.one_epoch(self, train)
26 for self.iter,self.batch in enumerate(self.dl):
27 with self.callback_ctx('batch'):
---> 28 self.predict()
29 self.get_loss()
30 if self.training:
Cell In[93], line 53, in Learner.callback(self, method_nm, **kw)
---> 53 def callback(self, method_nm, **kw): return run_cbs(self.cbs, method_nm, self, **kw)
Cell In[66], line 17, in run_cbs(cbs, method_nm, learn, run)
---> 17 run(gen)
# this was edited to match original run_cbs
Cell In[34], line 3, in TrainCB.predict(self, learn)
----> 3 def predict(self, learn): learn.preds = learn.model(learn.batch[0])
File /opt/homebrew/Caskroom/miniforge/base/envs/miniai/lib/python3.10/site-packages/torch/nn/modules/module.py:1190, in Module._call_impl(self, *input, **kwargs)
1186 # If we don't have any hooks, we want to skip the rest of the logic in
1187 # this function, and just call forward.
1188 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1189 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1190 return forward_call(*input, **kwargs)
1191 # Do not call functions when jit is used
1192 full_backward_hooks, non_full_backward_hooks = [], []
File /opt/homebrew/Caskroom/miniforge/base/envs/miniai/lib/python3.10/site-packages/torch/nn/modules/container.py:204, in Sequential.forward(self, input)
202 def forward(self, input):
203 for module in self:
--> 204 input = module(input)
205 return input
File /opt/homebrew/Caskroom/miniforge/base/envs/miniai/lib/python3.10/site-packages/torch/nn/modules/module.py:1190, in Module._call_impl(self, *input, **kwargs)
1186 # If we don't have any hooks, we want to skip the rest of the logic in
1187 # this function, and just call forward.
1188 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1189 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1190 return forward_call(*input, **kwargs)
1191 # Do not call functions when jit is used
1192 full_backward_hooks, non_full_backward_hooks = [], []
File /opt/homebrew/Caskroom/miniforge/base/envs/miniai/lib/python3.10/site-packages/torch/nn/modules/linear.py:114, in Linear.forward(self, input)
113 def forward(self, input: Tensor) -> Tensor:
--> 114 return F.linear(input, self.weight, self.bias)
RuntimeError: Placeholder storage has not been allocated on MPS device!
Becomes this traceback:
RuntimeError Traceback (most recent call last)
Cell In[96], line 4
2 cbs = [TrainCB(), metrics, ProgressCB(plot=True)]
3 learn = Learner(model, dls, F.cross_entropy, lr=0.2, cbs=cbs)
----> 4 learn.fit(1)
Cell In[93], line 42, in Learner.fit(self, n_epochs, train, valid, cbs, lr)
40 self.epochs = range(n_epochs)
41 self.opt = self.opt_func(self.model.parameters(), self.lr if lr is None else lr)
---> 42 with self.callback_ctx('fit'):
43 for self.epoch in self.epochs:
44 if train: self.one_epoch(True)
Cell In[95], line 10, in Callback_ctx.__exit__(self, type, value, tb)
8 if type is None: self.learn.callback(f'after_{self.nm}')
9 elif type is globals()[f'Cancel{self.nm.title()}Exception']: pass
---> 10 else: self.gen.throw(type, value, tb)
11 finally: self.learn.callback(f'cleanup_{self.nm}')
File <string>:1, in MetricsCB__before_fit(f, *args, **kwargs)
Cell In[66], line 9, in _cb_chain(cb, next_cb, learn)
7 else: cb(learn)
8 if next_cb is None: yield id(learn)
----> 9 else: yield from next_cb(learn)
File <string>:1, in ProgressCB__before_fit(f, *args, **kwargs)
Cell In[66], line 8, in _cb_chain(cb, next_cb, learn)
6 if isinstance(cb.__self__, SimpleCB): cb()
7 else: cb(learn)
----> 8 if next_cb is None: yield id(learn)
9 else: yield from next_cb(learn)
Cell In[93], line 44, in Learner.fit(self, n_epochs, train, valid, cbs, lr)
42 with self.callback_ctx('fit'):
43 for self.epoch in self.epochs:
---> 44 if train: self.one_epoch(True)
45 if valid: torch.no_grad()(self.one_epoch)(False)
46 finally:
Cell In[93], line 25, in Learner.one_epoch(self, train)
23 self.model.train(train)
24 self.dl = self.dls.train if train else self.dls.valid
---> 25 with self.callback_ctx('epoch'):
26 for self.iter,self.batch in enumerate(self.dl):
27 with self.callback_ctx('batch'):
Cell In[95], line 10, in Callback_ctx.__exit__(self, type, value, tb)
8 if type is None: self.learn.callback(f'after_{self.nm}')
9 elif type is globals()[f'Cancel{self.nm.title()}Exception']: pass
---> 10 else: self.gen.throw(type, value, tb)
11 finally: self.learn.callback(f'cleanup_{self.nm}')
File <string>:1, in MetricsCB__before_epoch(f, *args, **kwargs)
Cell In[66], line 9, in _cb_chain(cb, next_cb, learn)
7 else: cb(learn)
8 if next_cb is None: yield id(learn)
----> 9 else: yield from next_cb(learn)
File <string>:1, in ProgressCB__before_epoch(f, *args, **kwargs)
Cell In[66], line 8, in _cb_chain(cb, next_cb, learn)
6 if isinstance(cb.__self__, SimpleCB): cb()
7 else: cb(learn)
----> 8 if next_cb is None: yield id(learn)
9 else: yield from next_cb(learn)
Cell In[93], line 27, in Learner.one_epoch(self, train)
25 with self.callback_ctx('epoch'):
26 for self.iter,self.batch in enumerate(self.dl):
---> 27 with self.callback_ctx('batch'):
28 self.predict()
29 self.get_loss()
Cell In[95], line 10, in Callback_ctx.__exit__(self, type, value, tb)
8 if type is None: self.learn.callback(f'after_{self.nm}')
9 elif type is globals()[f'Cancel{self.nm.title()}Exception']: pass
---> 10 else: self.gen.throw(type, value, tb)
11 finally: self.learn.callback(f'cleanup_{self.nm}')
Cell In[66], line 8, in _cb_chain(cb, next_cb, learn)
6 if isinstance(cb.__self__, SimpleCB): cb()
7 else: cb(learn)
----> 8 if next_cb is None: yield id(learn)
9 else: yield from next_cb(learn)
Cell In[93], line 28, in Learner.one_epoch(self, train)
26 for self.iter,self.batch in enumerate(self.dl):
27 with self.callback_ctx('batch'):
---> 28 self.predict()
29 self.get_loss()
30 if self.training:
Cell In[93], line 53, in Learner.callback(self, method_nm, **kw)
---> 53 def callback(self, method_nm, **kw): return run_cbs(self.cbs, method_nm, self, **kw)
Cell In[66], line 17, in run_cbs(cbs, method_nm, learn, run)
15 gen = node(learn) if node else _cb_chain(id, None, learn)
16 try:
---> 17 run(gen)
18 except (CancelBatchException, CancelEpochException, CancelFitException ): raise
19 # except Exception as e:
20 # gen.throw(e)
File <string>:1, in TrainCB__predict(f, *args, **kwargs)
Cell In[66], line 7, in _cb_chain(cb, next_cb, learn)
5 def _cb_chain(cb, next_cb, learn):
6 if isinstance(cb.__self__, SimpleCB): cb()
----> 7 else: cb(learn)
8 if next_cb is None: yield id(learn)
9 else: yield from next_cb(learn)
Cell In[34], line 3, in TrainCB.predict(self, learn)
----> 3 def predict(self, learn): learn.preds = learn.model(learn.batch[0])
File /opt/homebrew/Caskroom/miniforge/base/envs/miniai/lib/python3.10/site-packages/torch/nn/modules/module.py:1190, in Module._call_impl(self, *input, **kwargs)
1186 # If we don't have any hooks, we want to skip the rest of the logic in
1187 # this function, and just call forward.
1188 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1189 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1190 return forward_call(*input, **kwargs)
1191 # Do not call functions when jit is used
1192 full_backward_hooks, non_full_backward_hooks = [], []
File /opt/homebrew/Caskroom/miniforge/base/envs/miniai/lib/python3.10/site-packages/torch/nn/modules/container.py:204, in Sequential.forward(self, input)
202 def forward(self, input):
203 for module in self:
--> 204 input = module(input)
205 return input
File /opt/homebrew/Caskroom/miniforge/base/envs/miniai/lib/python3.10/site-packages/torch/nn/modules/module.py:1190, in Module._call_impl(self, *input, **kwargs)
1186 # If we don't have any hooks, we want to skip the rest of the logic in
1187 # this function, and just call forward.
1188 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1189 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1190 return forward_call(*input, **kwargs)
1191 # Do not call functions when jit is used
1192 full_backward_hooks, non_full_backward_hooks = [], []
File /opt/homebrew/Caskroom/miniforge/base/envs/miniai/lib/python3.10/site-packages/torch/nn/modules/linear.py:114, in Linear.forward(self, input)
113 def forward(self, input: Tensor) -> Tensor:
--> 114 return F.linear(input, self.weight, self.bias)
RuntimeError: Placeholder storage has not been allocated on MPS device!
I was looking for a dynamic way of adding
yield from ...
at the end of each method. Any attempt at changing f.__code__ / exec makes the source code disappear in %debug, so there is no point in doing so. Since we are not going to add
yield from
/ await
to all calls dynamically, I think it is not worth changing before_* / after_*
to a single before/after with yield from,
as it will be used infrequently. This simplifies the code even more. I'm making this a PR so the proposal won't get buried in the Discord channel.