Lightning-AI / pytorch-lightning

Pretrain, finetune and deploy AI models on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

Reload new data after one epoch doesn't work. #14093

Closed · KralaBenjamin closed this issue 2 years ago

KralaBenjamin commented 2 years ago

🐛 Bug

To Reproduce

https://colab.research.google.com/drive/1V0LDGDjW_Ettv0q6fQXFtnd-j2aIURDm?usp=sharing Just follow the BoringModel in the notebook.

Expected behavior

tl;dr: The trainer should use or reload the data from the npDataModule class. It doesn't do so, because it never reaches the breakpoints set there.

Environment

Additional context

I would like to carry out Adversarial Defence Training. In the first epoch I load a data module and from the second epoch onwards I want to inject modified data into the training process at the beginning of each epoch.

I have tried three different approaches, and none of them works.

1.) The reload_dataloaders_every_n_epochs argument of the Trainer class. According to the documentation, I should define the functions train_dataloader(), val_dataloader(), etc. in my model (i.e. a LightningModule), from which the trainer then reloads the new dataloaders. This does not work (see the sketch after this list for what I expected).

2.) The functions

trainer.reset_train_dataloader()
trainer.reset_val_dataloader()

Here it is not clear where the data is loaded from (I find the documentation very sparse). It turns out that a new train_dataloader of type CombinedLoader is loaded, which is not what I wanted (val_dataloader remains None).

3.) I manually set datamodule, train_dataloader, etc. on the trainer object. This does not work either.
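
For reference, here is roughly how I understood approach 1.) from the documentation (a minimal, self-contained sketch; the model and the random data are placeholders, not my actual setup):

import torch
import pytorch_lightning as pl
from torch.utils.data import DataLoader, TensorDataset

class ReloadingModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.cross_entropy(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

    def train_dataloader(self):
        # I expected this to be called again at the start of every epoch
        # (because of reload_dataloaders_every_n_epochs=1), so that new or
        # modified data could be returned here.
        data = TensorDataset(torch.randn(64, 32), torch.randint(0, 2, (64,)))
        return DataLoader(data, batch_size=8)

trainer = pl.Trainer(max_epochs=3, reload_dataloaders_every_n_epochs=1)
trainer.fit(ReloadingModel())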

I am using the current version, 1.7. I don't know if I'm missing something (I can't find anything in the documentation), but my workflow is simple from a data perspective (just new data), and I can't find people reporting similar problems. I'm not sure whether it's a bug or whether I'm missing implicit assumptions. I would still appreciate help, as I have been sitting on this problem for many work days and am very desperate. Thank you very much!

cc @justusschock @awaelchli @ninginthecloud @rohitgr7 @otaj

awaelchli commented 2 years ago

Hi @KralaBenjamin

You are passing dataloaders directly to fit, but then wish to exchange them later on for a datamodule. This is not supported and not recommended. Here is a sketch of a possible way to do this:


import os

import pytorch_lightning as pl
from pytorch_lightning import Trainer
from torch.utils.data import DataLoader


class YourDatamodule(pl.LightningDataModule):
    def __init__(self, ...):
        ...
        self._train_data = ...

    def train_dataloader(self):
        ...
        # Called (again) whenever the trainer reloads the dataloaders, so it
        # always wraps whatever is currently stored in self._train_data.
        return DataLoader(self._train_data, ...)


class AdversarialDefence(pl.Callback):
    def __init__(self, attack_after_n_epochs):
        super().__init__()
        self.attack_after_n_epochs = attack_after_n_epochs

    def on_train_epoch_end(self, trainer, model):
        if trainer.current_epoch % self.attack_after_n_epochs == 0:
            new_data = ...
            # Swap the underlying data; the next dataloader reload picks it up.
            trainer.datamodule._train_data = new_data


def run():
    model = BoringModel()  # as defined in the linked Colab notebook
    datamodule = YourDatamodule()

    trainer = Trainer(
        default_root_dir=os.getcwd(),
        limit_train_batches=2,
        limit_val_batches=2,
        limit_test_batches=2,
        num_sanity_val_steps=0,
        max_epochs=5,
        enable_model_summary=False,
        reload_dataloaders_every_n_epochs=1,
        callbacks=[AdversarialDefence(attack_after_n_epochs=1)],
    )
    trainer.fit(model, datamodule)  # <--- here pass in the datamodule

You can also define some helper methods in your DataModule to shift the responsibility of updating the data from the callback to the datamodule. That might help.
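
For illustration, a rough sketch of what such a helper could look like (update_train_data is a hypothetical method name, not an existing Lightning API):

class YourDatamodule(pl.LightningDataModule):
    ...

    def update_train_data(self, new_data):
        # Hypothetical helper: the callback only hands over the new samples,
        # and the datamodule decides how they are stored and wrapped.
        self._train_data = new_data

class AdversarialDefence(pl.Callback):
    ...

    def on_train_epoch_end(self, trainer, model):
        if trainer.current_epoch % self.attack_after_n_epochs == 0:
            new_data = ...
            trainer.datamodule.update_train_data(new_data)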

In the above solution, the Trainer can rely on calling the train_dataloader() method on the datamodule, which handles taking your data and configuring the DataLoader object. Then you can just focus on updating your dataset object from either the callback or from within the datamodule.

I haven't looked closely into your code, but I think it would also be possible to create your adversarial data samples online, by replacing the existing data on the fly with the new samples, for example along the lines of the sketch below.
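
For example, something along these lines inside training_step (only a sketch, assuming a classification setup; the FGSM-style perturbation and the epsilon value are placeholders for whatever attack you actually use):

class AdversarialModel(pl.LightningModule):
    ...

    def training_step(self, batch, batch_idx):
        x, y = batch
        # Create adversarial samples on the fly instead of swapping the
        # dataset between epochs (simple FGSM-style perturbation as an example).
        x = x.detach().requires_grad_(True)
        clean_loss = torch.nn.functional.cross_entropy(self(x), y)
        grad, = torch.autograd.grad(clean_loss, x)
        x_adv = (x + 0.03 * grad.sign()).detach()
        return torch.nn.functional.cross_entropy(self(x_adv), y)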

KralaBenjamin commented 2 years ago

Thank you very much, it works now.