Lightning-AI / pytorch-lightning

Pretrain, finetune ANY AI model of ANY size on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

Better/correct test for callback #4049

Closed Borda closed 4 years ago

Borda commented 4 years ago

🚀 Feature

Reopen #4009 and let it merge...

Motivation

The current test with booleans is weak; it does not check the call order.

rohitgr7 commented 4 years ago

There are some hooks in both ModelHooks and Callback that I think are redundant: on_epoch_start, on_epoch_end, on_batch_start, on_batch_end.

These already have their alternatives with on_train_* prefix.

Should we remove them?

awaelchli commented 4 years ago
from unittest import mock
from unittest.mock import ANY, MagicMock, call

from pytorch_lightning import Trainer
from tests.base import EvalModelTemplate

@mock.patch("torch.save")  # need to mock torch.save or we get pickle error
def test_callback_system(torch_save):
    model = EvalModelTemplate()
    # pretend to be a callback, record all calls
    callback = MagicMock()
    trainer = Trainer(callbacks=[callback], max_steps=1, num_sanity_val_steps=0)
    trainer.fit(model)

    # check if a method was called exactly once
    callback.on_fit_start.assert_called_once()

    # check how many times a method was called
    assert callback.on_train_batch_end.call_count == 1

    # check that a method was NEVER called
    callback.on_keyboard_interrupt.assert_not_called()

    # check with what a method was called
    callback.on_fit_end.assert_called_with(trainer, model)

    # check exact call order
    callback.assert_has_calls([
        call.on_init_start(trainer),
        call.on_init_end(trainer),
        call.setup(trainer, None, "fit"),
        call.on_fit_start(trainer, model),
        call.on_pretrain_routine_start(trainer, model),
        call.on_pretrain_routine_end(trainer, model),
        call.on_train_start(trainer, model),
        call.on_epoch_start(trainer, model),
        call.on_train_epoch_start(trainer, model),
        # BATCH 0
        call.on_batch_start(trainer, model),
        # here we don't care about exact values in batch, so we say ANY
        call.on_train_batch_start(trainer, model, ANY, 0, 0),
        call.on_batch_end(trainer, model),
        # here we don't care about exact values in batch, so we say ANY
        call.on_train_batch_end(trainer, model, [], ANY, 0, 0),
        call.on_epoch_end(trainer, model),
        call.on_train_epoch_end(trainer, model, ANY),
        call.on_save_checkpoint(trainer, model),
        call.on_save_checkpoint().__bool__(),   # what's this lol?
        call.on_train_end(trainer, model),
        call.on_fit_end(trainer, model),
        call.teardown(trainer, model, "fit"),
    ])

Here is a simple example of how to track calls with unittest.mock. It is elegant, easy to understand, and lets you check that methods were called with the expected arguments and in the exact order.

Please consider testing the callbacks this way. The same approach could be applied to the model hooks (#4010). It is straightforward and easier to read than the old test.

awaelchli commented 4 years ago

https://docs.python.org/3/library/unittest.mock.html#module-unittest.mock

rohitgr7 commented 4 years ago

@awaelchli you working on this? Also, should we deprecate/remove https://github.com/PyTorchLightning/pytorch-lightning/issues/4049#issuecomment-706575159 first and then update the test?

awaelchli commented 4 years ago

What I posted here is all I worked on. It is a fully functional test plus a demo of other features. It can be extended a little and then replace all the custom callback tracking in the old test. Please feel free to take this code and use it. If not, I will find some time.

Also, should we deprecate/remove #4049 (comment) first and then update the test?

I believe hooks like on_epoch_start can be useful if we "redefine" them to be running on epoch start regardless of training, validation, or test. If this is not desired, I'd rather have them removed.

stale[bot] commented 4 years ago

This issue has been automatically marked as stale because it hasn't had any recent activity. This issue will be closed in 7 days if no further activity occurs. Thank you for your contributions, Pytorch Lightning Team!

Borda commented 4 years ago

It is very elegant, easy to understand and allows you to check that methods were called with the expected arguments and in the exact order.

I like that you can also test the call arguments, but it seems that assert_has_calls is order-independent: the test still passes when I shuffle the calls inside, or even remove them all and use callback.assert_has_calls([]) instead.

UPDATE: it seems that we can list all methods explicitly:

assert callback_mock.method_calls == [
    call.on_init_start(trainer),
    call.on_init_end(trainer),
]
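For context, here is a minimal standalone sketch (plain unittest.mock, no Lightning involved) of why comparing method_calls is the stricter check: equality requires an exact, complete, ordered match, so missing, extra, or reordered calls all fail.

```python
from unittest.mock import MagicMock, call

# record two method calls on a mock, as the Trainer would on a callback
m = MagicMock()
m.on_init_start("trainer")
m.on_init_end("trainer")

# method_calls must match exactly: same calls, same arguments, same order
assert m.method_calls == [
    call.on_init_start("trainer"),
    call.on_init_end("trainer"),
]

# a reordered expectation is NOT equal to the recorded call list
assert m.method_calls != [
    call.on_init_end("trainer"),
    call.on_init_start("trainer"),
]
```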
awaelchli commented 4 years ago

That's strange, because docs say it asserts the exact order in sequence: https://docs.python.org/3/library/unittest.mock.html#unittest.mock.Mock.assert_has_calls

Borda commented 4 years ago

That's strange, because docs say it asserts the exact order in sequence: https://docs.python.org/3/library/unittest.mock.html#unittest.mock.Mock.assert_has_calls

maybe some bug in the implementation, but if you simply test the replacement with an empty list, it passes too

awaelchli commented 4 years ago

yes, it makes sense that it passes with an empty list, because any number of calls can occur before or after the sequence you pass in. For example:

before()
a()
b()
c()
after()

assert_has_calls([]) # true
assert_has_calls([a]) # true
assert_has_calls([a, b]) # true
assert_has_calls([a, b, c]) # true
assert_has_calls([b, a, c]) # FALSE!!
assert_has_calls([before, a, b, c, after]) # true
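The matching rule sketched above can be verified in a standalone snippet (plain unittest.mock, not Lightning-specific): assert_has_calls accepts any contiguous, in-order subsequence of the recorded calls, and only a wrong order (or a call that never happened) raises AssertionError.

```python
from unittest.mock import MagicMock, call

# record a sequence of calls on a mock
m = MagicMock()
m.before()
m.a()
m.b()
m.c()
m.after()

# any contiguous, in-order subsequence passes, including the empty one
m.assert_has_calls([])
m.assert_has_calls([call.a()])
m.assert_has_calls([call.a(), call.b(), call.c()])
m.assert_has_calls([call.before(), call.a(), call.b(), call.c(), call.after()])

# a reordered sequence raises AssertionError
try:
    m.assert_has_calls([call.b(), call.a(), call.c()])
except AssertionError:
    pass  # expected: b, a, c never occurred in that order
else:
    raise RuntimeError("assert_has_calls unexpectedly passed")
```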
awaelchli commented 4 years ago

Thanks for taking care of this @Borda !

Borda commented 4 years ago

yes it makes sense that it passes with an empty list because any number of calls can occur before or after the sequence you pass in, so for example,

I see; it is a bit dangerous in case you miss some calls at the beginning or at the end and everything still seems to be fine... 8-)