Borda closed this issue 4 years ago.
There are some hooks in both ModelHooks and Callback that I think are redundant: on_epoch_start, on_epoch_end, on_batch_start, and on_batch_end. These already have alternatives with the on_train_* prefix. Should we remove them?
```python
from unittest import mock
from unittest.mock import ANY, MagicMock, call

from pytorch_lightning import Trainer
from tests.base import EvalModelTemplate


@mock.patch("torch.save")  # need to mock torch.save or we get a pickle error
def test_callback_system(torch_save):
    model = EvalModelTemplate()

    # pretend to be a callback, record all calls
    callback = MagicMock()
    trainer = Trainer(callbacks=[callback], max_steps=1, num_sanity_val_steps=0)
    trainer.fit(model)

    # check that a method was called exactly once
    callback.on_fit_start.assert_called_once()

    # check how many times a method was called
    assert callback.on_train_batch_end.call_count == 1

    # check that a method was NEVER called
    callback.on_keyboard_interrupt.assert_not_called()

    # check which arguments a method was called with
    callback.on_fit_end.assert_called_with(trainer, model)

    # check that these calls appear consecutively, in this order
    # (other calls may still occur before or after the listed ones)
    callback.assert_has_calls([
        call.on_init_start(trainer),
        call.on_init_end(trainer),
        call.setup(trainer, None, "fit"),
        call.on_fit_start(trainer, model),
        call.on_pretrain_routine_start(trainer, model),
        call.on_pretrain_routine_end(trainer, model),
        call.on_train_start(trainer, model),
        call.on_epoch_start(trainer, model),
        call.on_train_epoch_start(trainer, model),
        # BATCH 0
        call.on_batch_start(trainer, model),
        # here we don't care about the exact values in the batch, so we use ANY
        call.on_train_batch_start(trainer, model, ANY, 0, 0),
        call.on_batch_end(trainer, model),
        # here we don't care about the exact values in the batch, so we use ANY
        call.on_train_batch_end(trainer, model, [], ANY, 0, 0),
        call.on_epoch_end(trainer, model),
        call.on_train_epoch_end(trainer, model, ANY),
        call.on_save_checkpoint(trainer, model),
        # the return value of on_save_checkpoint was truth-tested (e.g. `if state:`),
        # and MagicMock records that as a __bool__ call on the returned mock
        call.on_save_checkpoint().__bool__(),
        call.on_train_end(trainer, model),
        call.on_fit_end(trainer, model),
        call.teardown(trainer, model, "fit"),
    ])
```
Here is a simple example of how to track calls with unittest.mock. It is very elegant, easy to understand, and allows you to check that methods were called with the expected arguments and in the exact order.
Please consider testing the callbacks this way. The same approach could be applied to the model hooks (#4010). It is very straightforward and also easier to read than the old test.
@awaelchli, are you working on this? Also, should we deprecate/remove https://github.com/PyTorchLightning/pytorch-lightning/issues/4049#issuecomment-706575159 first and then update the test?
What I posted here is all I have worked on. It is a fully functional test plus a demo of other mock features. It can be extended a little and then replace all the custom callback tracking in the old test. Please feel free to take this code and use it; if not, I will find some time.
> Also, should we deprecate/remove #4049 (comment) first and then update the test?
I believe hooks like on_epoch_start can be useful if we "redefine" them to run at the start of every epoch, regardless of whether it is a training, validation, or test epoch. If this is not desired, I'd rather have them removed.
> It is very elegant, easy to understand and allows you to check that methods were called with the expected arguments and in the exact order.
I like that you can also test the call arguments, but it seems that assert_has_calls is order-independent: the test still passes when I shuffle the calls inside it, or even when I replace the whole list so that it reads callback.assert_has_calls([]).
UPDATE: it seems that we can compare against the full list of method calls instead:
```python
assert callback_mock.method_calls == [
    call.on_init_start(trainer),
    call.on_init_end(trainer),
    # ... list the remaining hook calls here, in order
]
```
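A minimal standalone sketch of why this comparison is stricter, using a bare MagicMock in place of a real Trainer run (the hook names and the "trainer" string argument are placeholders, not real Lightning objects). Comparing method_calls with == fails on any reordering and on any missing or extra call. method_calls also leaves out calls on return values and magic methods, so an entry like call.on_save_checkpoint().__bool__() only shows up in mock_calls, which is the list assert_has_calls searches.

```python
from unittest.mock import MagicMock, call

# Stand-in for a callback; "trainer" is just a placeholder argument.
callback = MagicMock()
callback.on_init_start("trainer")
callback.on_init_end("trainer")
# Truth-testing a hook's return value is recorded as a __bool__ call.
bool(callback.on_save_checkpoint("trainer"))

expected = [
    call.on_init_start("trainer"),
    call.on_init_end("trainer"),
    call.on_save_checkpoint("trainer"),
]

# Exact comparison: same calls, same order, nothing missing or extra.
assert callback.method_calls == expected
# Reordered or truncated lists do not match.
assert callback.method_calls != expected[::-1]
assert callback.method_calls != expected[:2]

# The __bool__ call only appears in mock_calls, not in method_calls.
assert call.on_save_checkpoint().__bool__() in callback.mock_calls
assert call.on_save_checkpoint().__bool__() not in callback.method_calls
```

The trade-off is that an exact comparison has to list every recorded hook call, including ones the test does not care about.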
That's strange, because the docs say it asserts the calls in sequence, in the exact order: https://docs.python.org/3/library/unittest.mock.html#unittest.mock.Mock.assert_has_calls
Maybe it is a bug in the implementation, but if you simply replace the list with an empty one, it passes too.
Yes, it makes sense that it passes with an empty list, because any number of calls can occur before or after the sequence you pass in. For example, given these calls:
```python
before()
a()
b()
c()
after()

assert_has_calls([])  # true
assert_has_calls([a])  # true
assert_has_calls([a, b])  # true
assert_has_calls([a, b, c])  # true
assert_has_calls([b, a, c])  # FALSE!!
assert_has_calls([before, a, b, c, after])  # true
```
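The same behaviour can be checked with a runnable sketch. It uses a bare MagicMock with hypothetical methods before, a, b, c and after, plus a small helper passes() (an illustration, not part of unittest.mock) that reports whether assert_has_calls succeeds. The docs say that without any_order the calls must be sequential, with extra calls allowed only before or after them, so a gap such as [a, c] fails as well as a wrong order.

```python
from unittest.mock import MagicMock, call


def passes(mock, calls):
    """Hypothetical helper: True if mock.assert_has_calls(calls) succeeds."""
    try:
        mock.assert_has_calls(calls)
        return True
    except AssertionError:
        return False


m = MagicMock()
m.before()
m.a()
m.b()
m.c()
m.after()

assert passes(m, [])                                  # empty list always matches
assert passes(m, [call.a()])
assert passes(m, [call.a(), call.b()])
assert passes(m, [call.a(), call.b(), call.c()])
assert not passes(m, [call.b(), call.a(), call.c()])  # wrong order
assert not passes(m, [call.a(), call.c()])            # b is in between
assert passes(m, [call.before(), call.a(), call.b(), call.c(), call.after()])
```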
Thanks for taking care of this, @Borda!
> yes it makes sense that it passes with an empty list because any number of calls can occur before or after the sequence you pass in, so for example,
I see. That is a bit dangerous: if you miss some calls at the beginning or at the end, everything still seems to be fine... 8-)
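One way to guard against that, sketched here under the assumption that the recorded method_calls are the calls a test cares about: a hypothetical helper assert_exact_calls (not part of unittest.mock or Lightning) that requires the full list to match and names the first position that differs, which is easier to read than a diff of two long call lists.

```python
from unittest.mock import MagicMock, call


def assert_exact_calls(mock, expected):
    """Hypothetical helper: require mock.method_calls to equal `expected` exactly.

    Unlike assert_has_calls, unlisted calls before or after the expected
    sequence are rejected, and the error names the first differing position.
    """
    actual = list(mock.method_calls)
    for i, (got, want) in enumerate(zip(actual, expected)):
        if got != want:
            raise AssertionError(f"call #{i}: expected {want!r}, got {got!r}")
    if len(actual) != len(expected):
        raise AssertionError(
            f"expected {len(expected)} calls, got {len(actual)}: {actual!r}"
        )


m = MagicMock()
for name in ("before", "a", "b", "c", "after"):
    getattr(m, name)()  # record the calls in this order

# assert_has_calls tolerates the unlisted calls at both ends...
m.assert_has_calls([call.a(), call.b(), call.c()])

# ...while the exact check rejects the same list.
try:
    assert_exact_calls(m, [call.a(), call.b(), call.c()])
except AssertionError as err:
    print("exact check failed:", err)

# Listing every call makes it pass.
assert_exact_calls(m, [call.before(), call.a(), call.b(), call.c(), call.after()])
```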
🚀 Feature
Reopen #4009 and let it merge...
Motivation
The current test, which uses bools, is weak: it does not check the call order.
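To illustrate the motivation, here is a minimal sketch with two hypothetical callback classes (plain Python, not Lightning code; the hooks take no arguments for brevity). Bool flags only record that each hook ran, so a loop that fires the hooks in the wrong order still passes a flag-based check, while a callback that appends each hook name to a list exposes the order.

```python
class FlagCallback:
    """Hypothetical bool-flag style callback: one flag per hook."""

    def __init__(self):
        self.on_train_start_called = False
        self.on_train_end_called = False

    def on_train_start(self):
        self.on_train_start_called = True

    def on_train_end(self):
        self.on_train_end_called = True


class RecordingCallback:
    """Hypothetical callback that records hook names in call order."""

    def __init__(self):
        self.calls = []

    def on_train_start(self):
        self.calls.append("on_train_start")

    def on_train_end(self):
        self.calls.append("on_train_end")


def run_hooks_in_wrong_order(cb):
    # Simulates a buggy loop that fires the end hook before the start hook.
    cb.on_train_end()
    cb.on_train_start()


flags = FlagCallback()
run_hooks_in_wrong_order(flags)
# Both flags are set, so a flag-based check cannot see the wrong order.
assert flags.on_train_start_called and flags.on_train_end_called

rec = RecordingCallback()
run_hooks_in_wrong_order(rec)
# The recorded sequence does not match the expected order.
assert rec.calls != ["on_train_start", "on_train_end"]
```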