Lightning-AI / pytorch-lightning

Pretrain, finetune ANY AI model of ANY size on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

[RFC] Support a `Trainer.train()` API #10888

Closed: ananthsub closed this issue 2 years ago.

ananthsub commented 2 years ago

🚀 Feature

Add a new entry point to the Trainer which runs only the training loop with no validation.

Motivation

This makes it clear that if users only define training_step and train_dataloader, they can call train without any risk of errors from unimplemented validation hooks, though the framework already checks for this today.

Another motivation is that users who do implement validation steps/dataloaders may want to run training alone, without validation (for example, in online training). Today, those users need to make sure they set limit_val_batches=0 before calling trainer.fit.

Finally, such a feature makes it easier to interleave train/validate/test/predict calls. For example, past requests have been made to run the test loop after each validation pass. In conjunction with #10444 this makes writing more complex patterns far simpler with Lightning.

This is slightly different from loop customization. In this case, I don't want to change any of the fundamental building blocks, but I may want to change the order/sequencing in which they're called.

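from pytorch_lightning import LightningModule, Trainer

t = Trainer(...)
m = MyLightningModule(...)

# Proposed usage: Trainer.train is the new entry point requested here, and
# dry_run/max_epochs are not arguments of today's Trainer methods.
def my_custom_interleaving(t: Trainer, m: LightningModule):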
    t.train(m, dry_run=True)  # make sure everything runs okay, fail fast if not
    t.train(m, max_epochs=1)
    t.validate(m)
    t.test(m)
    t.predict(m)
    t.fit(m, max_epochs=10)
    t.validate(m)
    t.test(m)
    t.predict(m)

Pitch

Offer a top-level function on the Trainer:

def train(
    self,
    model: "pl.LightningModule",
    train_dataloaders: Optional[Union[TRAIN_DATALOADERS, LightningDataModule]] = None,
    datamodule: Optional[LightningDataModule] = None,
    ckpt_path: Optional[str] = None,
) -> None:
    ...

Alternatives

One could try to work around this as follows:

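t = Trainer(..., limit_val_batches=0)
t.fit(model)               # trains with validation disabled
t.limit_val_batches = 1.0  # re-enable validation for the next call
t.validate(model)
t.limit_val_batches = 0    # disable it again before training resumes
...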

However, this is somewhat clunky to write, and requires users to dig through the various trainer properties/attributes to reset state across calls, which is not straightforward.
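One way to make the workaround above less error-prone, without any new Trainer API, is to scope the attribute change with a context manager so the original value is always restored, even if `fit` raises. This is a minimal sketch: `temporarily_set` and `train_only` are hypothetical helpers, not part of Lightning, and the sketch assumes (as the snippet above does) that assigning `trainer.limit_val_batches` after construction is honored by the next call.

```python
from contextlib import contextmanager
from typing import Any, Iterator


@contextmanager
def temporarily_set(obj: Any, **attrs: Any) -> Iterator[Any]:
    """Set attributes on `obj` for the duration of a `with` block, then restore them."""
    # Remember the current values so they can be put back afterwards.
    saved = {name: getattr(obj, name) for name in attrs}
    try:
        for name, value in attrs.items():
            setattr(obj, name, value)
        yield obj
    finally:
        # Restore the original values even if the body raised.
        for name, value in saved.items():
            setattr(obj, name, value)


def train_only(trainer: Any, model: Any, **fit_kwargs: Any) -> Any:
    """Hypothetical helper: run trainer.fit with validation disabled for this call only."""
    with temporarily_set(trainer, limit_val_batches=0):
        return trainer.fit(model, **fit_kwargs)
```

With these helpers, the sequence above becomes `train_only(t, model)` followed by `t.validate(model)`, with no manual reset of `limit_val_batches` in between.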

Additional context



cc @borda @justusschock @kaushikb11 @awaelchli @ananthsub @ninginthecloud @jjenniferdai @rohitgr7

carmocca commented 2 years ago

Do you think users might want to implement different training logic when they call trainer.fit vs trainer.train? If yes, do you think our current on_train_* hooks should have been named on_fit_*?

I guess validation during trainer.fit and trainer.validate have the same problem anyway. Users can check trainer.state.fn == "validate" to tell them apart.
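A minimal sketch of the pattern described in the comment above: a validation hook that behaves differently depending on which Trainer entry point is running. To keep it runnable without Lightning installed, the trainer is a `SimpleNamespace` stand-in exposing only `state.fn`; in a real `LightningModule`, the hook would read `self.trainer.state.fn` instead. The class name, the `calls` list, and the `"fit"` value assigned to the stand-in are illustrative assumptions.

```python
from types import SimpleNamespace


class MyModule:
    """Stand-in for a LightningModule; only the branching logic is shown."""

    def __init__(self, trainer):
        self.trainer = trainer
        self.calls = []

    def on_validation_epoch_end(self):
        # trainer.state.fn identifies the entry point: "validate" for a standalone
        # trainer.validate call, anything else here means validation inside fit.
        if self.trainer.state.fn == "validate":
            self.calls.append("standalone validate")
        else:
            self.calls.append("validation inside fit")


fake_trainer = SimpleNamespace(state=SimpleNamespace(fn="validate"))
module = MyModule(fake_trainer)
module.on_validation_epoch_end()
fake_trainer.state.fn = "fit"  # simulate the same hook firing during fit
module.on_validation_epoch_end()
```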

ananthsub commented 2 years ago

Do you think users might want to implement different training logic when they call trainer.fit vs trainer.train? If yes, do you think our current on_train_* hooks should have been named on_fit_*?

No, I think on_train_* hooks are the right call, as train is a fundamental building block (represented by RunningStage): https://github.com/PyTorchLightning/pytorch-lightning/blob/a28b4cd0c0bba30c21cae571e650877f66cf5588/pytorch_lightning/trainer/states.py#L56

and Trainer functions are higher-level compositions which can run 1+ RunningStages: https://github.com/PyTorchLightning/pytorch-lightning/blob/a28b4cd0c0bba30c21cae571e650877f66cf5588/pytorch_lightning/trainer/states.py#L34
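To illustrate the distinction drawn above, here is a toy model (not Lightning's code) in which each Trainer function is an ordered composition of running stages, so a `train` entry point would differ from `fit` only in which stages it composes. `Stage`, `TRAINER_FN_STAGES`, and `run` are hypothetical simplifications; the real `RunningStage` and `TrainerFn` enums in the linked `states.py` have more members, and `"train"` as a Trainer function is the addition proposed in this RFC.

```python
from enum import Enum
from typing import Dict, List


class Stage(str, Enum):
    """Simplified stand-in for Lightning's RunningStage."""

    SANITY_CHECK = "sanity_check"
    TRAIN = "train"
    VALIDATE = "validate"
    TEST = "test"
    PREDICT = "predict"


# Hypothetical mapping from each Trainer function to the stages it runs.
# "train" reuses the training stage of "fit" but drops the sanity check and validation.
TRAINER_FN_STAGES: Dict[str, List[Stage]] = {
    "fit": [Stage.SANITY_CHECK, Stage.TRAIN, Stage.VALIDATE],
    "train": [Stage.TRAIN],
    "validate": [Stage.VALIDATE],
    "test": [Stage.TEST],
    "predict": [Stage.PREDICT],
}


def run(fn: str, epochs: int = 1) -> List[str]:
    """Return the sequence of stages a toy trainer would execute for `fn`."""
    stages = TRAINER_FN_STAGES[fn]
    trace: List[str] = []
    if Stage.SANITY_CHECK in stages:
        trace.append(Stage.SANITY_CHECK.value)  # once, before the first epoch
    per_epoch = [s for s in stages if s is not Stage.SANITY_CHECK]
    # Only functions that train repeat their stages once per epoch.
    repeats = epochs if Stage.TRAIN in stages else 1
    for _ in range(repeats):
        trace.extend(s.value for s in per_epoch)
    return trace
```

Under this model, the interleaving in the RFC's first example reduces to concatenating compositions, e.g. `run("train") + run("validate") + run("test")`, without touching the stages themselves.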

As you note, if the user wants some logic to run in validation hooks during trainer.fit but not during trainer.validate, the trainer provides access to the state (running stage, fn) to distinguish the two.

carmocca commented 2 years ago

Some questions:

ananthsub commented 2 years ago

Would you expect that trainer.train shares the results and progress tracking state with trainer.fit?

No, I'd expect train to operate independently, the same way trainer.validate keeps track of its own state.

Would you ask for this feature if trainer.limit_val_batches was part of trainer.fit?

The need would be far less, but I think a dedicated entry point is clearer for users and provides greater confidence that the framework doesn't initialize or check anything related to validation (including validation sanity checks). fit with limit_val_batches=0 would functionally be the same thing though.

Does this addition mean trainer.fit should raise an exception when there's no validation?

I think we could keep the current behavior of not checking for validation when users call trainer.fit with limit_val_batches=0. But by default, yes, we could raise an exception.
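The behavior proposed in this answer can be sketched as a small pre-flight check. This is hypothetical: `check_fit_inputs` and its parameters are illustrative and do not reflect Lightning's actual checks, which (per this discussion) currently skip the validation stage rather than raising.

```python
def check_fit_inputs(
    has_validation_step: bool, has_val_dataloader: bool, limit_val_batches: float
) -> bool:
    """Hypothetical pre-flight check for trainer.fit under the proposal above.

    Returns True if validation should run, False if the user explicitly opted out,
    and raises if validation is missing without an explicit opt-out.
    """
    if limit_val_batches == 0:
        # Explicit opt-out keeps today's behavior: validation is skipped, no error.
        return False
    if not (has_validation_step and has_val_dataloader):
        # Proposed default: fit without validation is an error; the proposed
        # trainer.train() (not an existing API) would be the intended path instead.
        raise ValueError(
            "trainer.fit requires a validation_step and a validation dataloader; "
            "call trainer.train() or set limit_val_batches=0 to train without validation."
        )
    return True
```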

Then users have a really clear path for onboarding:

One could argue that fit is best practice compared to train, since one should always have separate training and validation data. But the framework doesn't enforce this today: it checks whether the validation hooks are overridden and skips those stages if they're not.

stale[bot] commented 2 years ago

This issue has been automatically marked as stale because it hasn't had any recent activity. This issue will be closed in 7 days if no further activity occurs. Thank you for your contributions, Pytorch Lightning Team!

Borda commented 2 years ago

Add a new entry point to the Trainer which runs only the training loop with no validation.

Could you please provide an example use case? Personally, I feel that it goes against the PL best-practice principle, which in most cases includes progress monitoring, and that should be done on the validation set, not on the training set... Also, wouldn't you get the same functionality by limiting the number of validation batches to zero in the Trainer's init?

carmocca commented 2 years ago

It's a sensible proposal, but it has the big drawback of creating confusion with trainer.fit, especially because fit with no validation data will still be supported. If fitting with no training data had been supported, we might not have added trainer.validate.

However, I see the advantage for external loop customization or orchestrating multiple calls.

Would you implement the loop class used as a copy of the FitLoop and its children with all validation logic removed or would you just disable the validation data and use the FitLoop?

tchaton commented 2 years ago

After reflecting on this, @ananthsub I believe this shouldn't be added.

First, because the Trainer API is final, but more importantly because it would push bad practices on the user. I am 100% sure @williamFalcon and co. thought hard about this at the beginning, and that this option was deliberately left out from the start.

IMO, the best practice encouraged by the trainer.fit default is to perform sanity checking. Furthermore, opting out is quite simple but should be the user's responsibility, e.g. Trainer(limit_val_batches=0) or not providing a validation dataloader.

+1 on the concern about added confusion.

@awaelchli @carmocca I would recommend closing this RFC.

Thanks @ananthsub for your time and effort proposing this :)

gzerveas commented 2 years ago

The existing solution, e.g. Trainer(limit_val_batches=0), is sufficient, but if you are wondering about possible use cases for not performing validation, here is one: one first defines a custom train/validation split of the training set for optimizing hyperparameters, and once those are fixed, one wants to use the entire training set for training. Evaluation may happen on a separate test set at the end, or may not be possible at all (e.g. a hidden test set).
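A minimal sketch of this two-phase workflow, using plain index lists so it runs without any framework. `split_indices` is a hypothetical helper; in Lightning, phase 1 would pass both dataloaders to `trainer.fit`, and phase 2 would train on all indices with `Trainer(limit_val_batches=0)` as suggested above.

```python
import random
from typing import List, Tuple


def split_indices(n: int, val_fraction: float, seed: int = 0) -> Tuple[List[int], List[int]]:
    """Shuffle range(n) reproducibly and split it into (train, val) index lists."""
    indices = list(range(n))
    random.Random(seed).shuffle(indices)
    n_val = int(n * val_fraction)
    return indices[n_val:], indices[:n_val]


# Phase 1: tune hyperparameters on a held-out split of the training set.
tune_train, tune_val = split_indices(n=1000, val_fraction=0.1, seed=42)

# Phase 2: hyperparameters are fixed, so train on every index with no validation.
final_train = list(range(1000))
```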

extragoya commented 1 year ago

The solution is indeed sufficient, although improved documentation would be welcome. For those interested, another use case is training models to reconstruct shapes or scenes, e.g., DeepSDF https://openaccess.thecvf.com/content_CVPR_2019/papers/Park_DeepSDF_Learning_Continuous_Signed_Distance_Functions_for_Shape_Representation_CVPR_2019_paper.pdf. I believe neural radiance fields would have a similar use case, and are very popular.