Lightning-AI / pytorch-lightning

Pretrain, finetune and deploy AI models on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

LightningModule.train_dataloader not being called #19673

Open lsc64 opened 3 months ago

lsc64 commented 3 months ago

Bug description

The train_dataloader hook of LightningModule is not being called from Trainer.fit. I need to put code there that modifies the dataloader and requires access to the optimizers, as follows:

class Classifier(LightningModule):
    def __init__(
        self,
        *args,
        **kwargs,
    ):
        super().__init__()
        # model initialized here

    def train_dataloader(self) -> Any:
        dl = self.trainer.datamodule.train_dataloader()
        if not hasattr(self.trainer.datamodule, "batch_size_physical"):
            return dl # just use the LightningDataModule as is
        # wrap using this function otherwise
        return wrap_data_loader(
            data_loader=dl,
            max_batch_size=self.trainer.datamodule.batch_size_physical,
            optimizer=self.optimizer,
        )
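
For context, wrap_data_loader with this exact keyword signature looks like Opacus' batch-memory helper (opacus is listed in the environment below). Assuming that is the helper meant here, the import would be:

# Assumption (not stated in the issue): wrap_data_loader is Opacus' helper for
# splitting logical batches into smaller physical batches tied to the optimizer.
from opacus.utils.batch_memory_manager import wrap_data_loader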

What version are you seeing the problem on?

v2.1, v2.2

How to reproduce the bug

Run the following code. It should print "Hello from train_dataloader in the LightningModule" if the hook is being called.


import os

import torch
from lightning.pytorch import LightningDataModule, LightningModule, Trainer
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import MNIST

class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(28, 2)

    def forward(self, batch):
        x, y = batch
        return self.layer(x)

    def train_dataloader(self):
        print("Hello from train_dataloader in the LightningModule")
        return super().train_dataloader()

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("valid_loss", loss)

    def test_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("test_loss", loss)

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)

class MNISTDataModule(LightningDataModule):
    def __init__(self, data_dir: str = "./"):
        super().__init__()
        self.data_dir = data_dir
        self.transform = transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
        )

    def prepare_data(self):
        # download
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage: str):
        # Assign train/val datasets for use in dataloaders
        if stage == "fit":
            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
            self.mnist_train, self.mnist_val = random_split(
                mnist_full, [55000, 5000], generator=torch.Generator().manual_seed(42)
            )

        # Assign test dataset for use in dataloader(s)
        if stage == "test":
            self.mnist_test = MNIST(
                self.data_dir, train=False, transform=self.transform
            )

        if stage == "predict":
            self.mnist_predict = MNIST(
                self.data_dir, train=False, transform=self.transform
            )

    def train_dataloader(self):
        print("Hello from train_dataloader in the LightningDataModule")
        return DataLoader(self.mnist_train, batch_size=32)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=32)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=32)

    def predict_dataloader(self):
        return DataLoader(self.mnist_predict, batch_size=32)

def main():
    model = BoringModel()
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        devices=1,
        limit_train_batches=1,
        limit_val_batches=1,
        limit_test_batches=1,
        num_sanity_val_steps=0,
        max_epochs=1,
        enable_model_summary=False,
    )
    datamodule = MNISTDataModule()
    trainer.fit(model, datamodule=datamodule)

if __name__ == "__main__":
    main()

Error messages and logs

$ python boring_snippet.py
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
`Trainer(limit_train_batches=1)` was configured so 1 batch per epoch will be used.
`Trainer(limit_val_batches=1)` was configured so 1 batch will be used.
`Trainer(limit_test_batches=1)` was configured so 1 batch will be used.
You are using a CUDA device ('NVIDIA A100-SXM4-40GB') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]
Hello from train_dataloader in the LightningDataModule
/opt/conda/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=255` in the `DataLoader` to improve performance.
/opt/conda/lib/python3.11/site-packages/lightning/pytorch/loops/fit_loop.py:298: The number of training batches (1) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.
/opt/conda/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=255` in the `DataLoader` to improve performance.
Epoch 0: 100%|_____________________________________________________________________________________________________________________________________________| 1/1 [00:00<00:00,  7.82it/s, v_num=49]
`Trainer.fit` stopped: `max_epochs=1` reached.
Epoch 0: 100%|_____________________________________________________________________________________________________________________________________________| 1/1 [00:00<00:00,  7.67it/s, v_num=49]

Environment

Current environment * CUDA: - GPU: - NVIDIA A100-SXM4-40GB - NVIDIA A100-SXM4-40GB - NVIDIA A100-SXM4-40GB - NVIDIA A100-SXM4-40GB - NVIDIA A100-SXM4-40GB - NVIDIA A100-SXM4-40GB - NVIDIA A100-SXM4-40GB - NVIDIA A100-SXM4-40GB - available: True - version: 12.1 * Lightning: - lightning: 2.2.1 - lightning-bolts: 0.7.0 - lightning-utilities: 0.11.0 - pytorch-lightning: 2.2.1 - torch: 2.2.1 - torchaudio: 2.2.1 - torchmetrics: 1.3.2 - torchvision: 0.17.1 * Packages: - absl-py: 2.1.0 - aiohttp: 3.9.3 - aiohttp-cors: 0.7.0 - aiosignal: 1.3.1 - alembic: 1.13.1 - aniso8601: 9.0.1 - annotated-types: 0.6.0 - antlr4-python3-runtime: 4.9.3 - anyio: 4.3.0 - archspec: 0.2.2 - argon2-cffi: 23.1.0 - argon2-cffi-bindings: 21.2.0 - arrow: 1.3.0 - asciitree: 0.3.3 - asttokens: 2.4.1 - async-lru: 2.0.4 - attrs: 23.2.0 - autodp: 0.2.3.1 - babel: 2.14.0 - bcrypt: 4.1.2 - beautifulsoup4: 4.12.3 - bleach: 6.1.0 - blessed: 1.19.1 - blinker: 1.7.0 - boltons: 23.1.1 - brotli: 1.1.0 - cached-property: 1.5.2 - cachetools: 5.3.3 - certifi: 2024.2.2 - cffi: 1.16.0 - cfgv: 3.4.0 - chardet: 5.2.0 - charset-normalizer: 3.3.2 - chex: 0.1.85 - click: 8.1.7 - cloudpickle: 3.0.0 - colorama: 0.4.6 - colorful: 0.5.6 - colorlog: 6.8.2 - comm: 0.2.2 - conda: 24.1.2 - conda-build: 24.1.2 - conda-index: 0.4.0 - conda-libmamba-solver: 23.12.0 - conda-package-handling: 2.2.0 - conda-package-streaming: 0.9.0 - contourpy: 1.2.0 - cryptography: 42.0.5 - cycler: 0.12.1 - dask: 2024.2.1 - debugpy: 1.8.1 - decorator: 5.1.1 - defusedxml: 0.7.1 - diffprivlib: 0.6.4 - distlib: 0.3.8 - distro: 1.8.0 - dm-tree: 0.1.8 - docker: 7.0.0 - dp-learning-ff: 0.0.9.dev23+g5b7d4b5.d20240319 - entrypoints: 0.4 - equinox: 0.11.3 - etils: 1.7.0 - exceptiongroup: 1.2.0 - executing: 2.0.1 - fasteners: 0.17.3 - fastjsonschema: 2.19.1 - filelock: 3.13.1 - flask: 3.0.2 - flax: 0.8.2 - fonttools: 4.49.0 - fqdn: 1.5.1 - frozenlist: 1.4.1 - fsspec: 2024.3.1 - gast: 0.5.4 - gitdb: 4.0.11 - gitpython: 3.1.42 - gmpy2: 2.1.2 - google-api-core: 2.17.1 - google-auth: 2.28.1 - google-vizier: 0.1.15 - googleapis-common-protos: 1.62.0 - gpustat: 1.1.1 - graphene: 3.3 - graphql-core: 3.2.3 - graphql-relay: 3.2.0 - greenlet: 3.0.3 - grpcio: 1.62.1 - grpcio-tools: 1.62.1 - gunicorn: 21.2.0 - h11: 0.14.0 - h5py: 3.10.0 - httpcore: 1.0.4 - httpx: 0.27.0 - huggingface-hub: 0.21.4 - hydra-core: 1.3.2 - identify: 2.5.35 - idna: 3.6 - importlib-metadata: 7.0.1 - importlib-resources: 6.1.2 - ipykernel: 6.29.3 - ipython: 8.22.2 - ipywidgets: 8.1.2 - isoduration: 20.11.0 - itsdangerous: 2.1.2 - jax: 0.4.25 - jaxlib: 0.4.25 - jaxopt: 0.8.3 - jaxtyping: 0.2.28 - jedi: 0.19.1 - jinja2: 3.1.3 - joblib: 1.3.2 - json5: 0.9.24 - jsonpatch: 1.33 - jsonpointer: 2.4 - jsonschema: 4.21.1 - jsonschema-specifications: 2023.12.1 - jupyter-client: 8.6.0 - jupyter-core: 5.7.1 - jupyter-events: 0.9.0 - jupyter-lsp: 2.2.4 - jupyter-server: 2.13.0 - jupyter-server-mathjax: 0.2.6 - jupyter-server-terminals: 0.5.2 - jupyterlab: 4.1.5 - jupyterlab-git: 0.50.0 - jupyterlab-pygments: 0.3.0 - jupyterlab-server: 2.25.4 - jupyterlab-widgets: 3.0.10 - kiwisolver: 1.4.5 - libarchive-c: 5.0 - libmambapy: 1.5.7 - lightning: 2.2.1 - lightning-bolts: 0.7.0 - lightning-utilities: 0.11.0 - locket: 1.0.0 - mako: 1.3.2 - mamba: 1.5.7 - markdown: 3.5.2 - markdown-it-py: 3.0.0 - markupsafe: 2.1.5 - matplotlib: 3.8.3 - matplotlib-inline: 0.1.6 - mdurl: 0.1.2 - memory-tempfile: 2.2.3 - menuinst: 2.0.2 - mistune: 3.0.2 - ml-dtypes: 0.3.2 - mlflow: 2.11.0 - mlflow-skinny: 2.11.0 - more-itertools: 10.2.0 - mpmath: 1.3.0 - msgpack: 
1.0.7 - multidict: 6.0.5 - munkres: 1.1.4 - nbclient: 0.8.0 - nbconvert: 7.16.2 - nbdime: 4.0.1 - nbformat: 5.9.2 - nest-asyncio: 1.6.0 - networkx: 3.2.1 - nodeenv: 1.8.0 - notebook-shim: 0.2.4 - numcodecs: 0.12.1 - numpy: 1.26.4 - nvidia-cublas-cu12: 12.1.3.1 - nvidia-cuda-cupti-cu12: 12.1.105 - nvidia-cuda-nvrtc-cu12: 12.1.105 - nvidia-cuda-runtime-cu12: 12.1.105 - nvidia-cudnn-cu12: 8.9.2.26 - nvidia-cufft-cu12: 11.0.2.54 - nvidia-curand-cu12: 10.3.2.106 - nvidia-cusolver-cu12: 11.4.5.107 - nvidia-cusparse-cu12: 12.1.0.106 - nvidia-ml-py: 12.535.133 - nvidia-nccl-cu12: 2.19.3 - nvidia-nvjitlink-cu12: 12.4.99 - nvidia-nvtx-cu12: 12.1.105 - omegaconf: 2.3.0 - opacus: 1.4.1 - opencensus: 0.11.4 - opencensus-context: 0.1.3 - opt-einsum: 3.3.0 - optax: 0.2.1 - optuna: 3.5.0 - orbax-checkpoint: 0.5.6 - overrides: 7.7.0 - packaging: 24.0 - pandas: 2.2.1 - pandocfilters: 1.5.0 - paramiko: 3.4.0 - parso: 0.8.3 - partd: 1.4.1 - pexpect: 4.9.0 - pickleshare: 0.7.5 - pillow: 9.4.0 - pip: 23.3.2 - pkginfo: 1.10.0 - pkgutil-resolve-name: 1.3.10 - platformdirs: 4.1.0 - pluggy: 1.3.0 - portpicker: 1.6.0 - pre-commit: 3.6.2 - prometheus-client: 0.20.0 - prometheus-flask-exporter: 0.23.0 - prompt-toolkit: 3.0.42 - protobuf: 4.24.4 - psutil: 5.9.8 - psycopg2-binary: 2.9.9 - ptyprocess: 0.7.0 - pure-eval: 0.2.2 - py-spy: 0.3.14 - pyarrow: 15.0.0 - pyasn1: 0.5.1 - pyasn1-modules: 0.3.0 - pycosat: 0.6.6 - pycparser: 2.21 - pydantic: 2.6.3 - pydantic-core: 2.16.3 - pygments: 2.17.2 - pynacl: 1.5.0 - pyparsing: 3.1.2 - pysocks: 1.7.1 - python-dateutil: 2.9.0 - python-dp: 1.1.4 - python-json-logger: 2.0.7 - pytorch-lightning: 2.2.1 - pytz: 2024.1 - pyyaml: 6.0.1 - pyzmq: 25.1.2 - querystring-parser: 1.2.4 - ray: 2.9.3 - referencing: 0.33.0 - regex: 2023.12.25 - requests: 2.31.0 - rfc3339-validator: 0.1.4 - rfc3986-validator: 0.1.1 - rich: 13.7.1 - rpds-py: 0.18.0 - rsa: 4.9 - ruamel.yaml: 0.18.6 - ruamel.yaml.clib: 0.2.8 - ruff: 0.3.3 - safetensors: 0.4.2 - scikit-learn: 1.4.1.post1 - scipy: 1.12.0 - seaborn: 0.13.2 - send2trash: 1.8.2 - setuptools: 68.2.2 - six: 1.16.0 - skorch: 0.15.0 - smart-open: 7.0.1 - smmap: 5.0.0 - sniffio: 1.3.1 - soupsieve: 2.5 - sqlalchemy: 2.0.28 - sqlparse: 0.4.4 - stack-data: 0.6.2 - sympy: 1.12 - tabulate: 0.9.0 - tensorboard: 2.16.2 - tensorboard-data-server: 0.7.0 - tensorstore: 0.1.56 - terminado: 0.18.0 - tfp-nightly: 0.25.0.dev20240318 - threadpoolctl: 3.3.0 - timm: 0.9.16 - tinycss2: 1.2.1 - tokenizers: 0.15.2 - toolz: 0.12.1 - torch: 2.2.1 - torchaudio: 2.2.1 - torchmetrics: 1.3.2 - torchvision: 0.17.1 - tornado: 6.4 - tqdm: 4.66.2 - traitlets: 5.14.1 - transformers: 4.38.2 - triton: 2.2.0 - truststore: 0.8.0 - typeguard: 2.13.3 - types-python-dateutil: 2.8.19.20240106 - typing-extensions: 4.10.0 - typing-utils: 0.1.0 - tzdata: 2024.1 - uri-template: 1.3.0 - urllib3: 2.1.0 - uv: 0.1.22 - virtualenv: 20.25.1 - vit-proto: 0.0.0 - wcwidth: 0.2.13 - webcolors: 1.13 - webencodings: 0.5.1 - websocket-client: 1.7.0 - werkzeug: 3.0.1 - wheel: 0.42.0 - widgetsnbextension: 4.0.10 - wrapt: 1.16.0 - yarl: 1.9.4 - zarr: 2.17.1 - zipp: 3.17.0 - zstandard: 0.22.0 * System: - OS: Linux - architecture: - 64bit - - processor: x86_64 - python: 3.11.8 - release: 5.4.0-173-generic - version: #191-Ubuntu SMP Fri Feb 2 13:55:07 UTC 2024

More info

No response

cc @carmocca @awaelchli @borda

awaelchli commented 3 months ago

Hi @lsc64

As I explained on the forum, if you use a datamodule, the dataloader methods on the LightningModule won't be called. This is by design and expected. If you have a dependency on model-specific attributes, then LightningDataModule is not a good choice. It should only be used if you can decouple the data definitions from your model.
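
To spell out the precedence (a rough sketch, not a verbatim quote of the docs):

trainer.fit(model, train_dataloaders=dl)  # an explicitly passed dataloader is used directly
trainer.fit(model, datamodule=dm)         # otherwise the datamodule's train_dataloader() is used
trainer.fit(model)                        # only without either does model.train_dataloader() get called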

In your case, implement the LightningModule methods instead:

class BoringModel(LightningModule):
    def __init__(self, data_dir: str = "./"):
        super().__init__()
        # data_dir and transform are needed by prepare_data/setup below
        self.data_dir = data_dir
        self.transform = transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
        )
        self.layer = torch.nn.Linear(28, 2)

    def forward(self, batch):
        x, y = batch
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("valid_loss", loss)

    def test_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("test_loss", loss)

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)

    def prepare_data(self):
        # download
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage: str):
        # Assign train/val datasets for use in dataloaders
        if stage == "fit":
            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
            self.mnist_train, self.mnist_val = random_split(
                mnist_full, [55000, 5000], generator=torch.Generator().manual_seed(42)
            )

        # Assign test dataset for use in dataloader(s)
        if stage == "test":
            self.mnist_test = MNIST(
                self.data_dir, train=False, transform=self.transform
            )

        if stage == "predict":
            self.mnist_predict = MNIST(
                self.data_dir, train=False, transform=self.transform
            )

    def train_dataloader(self) -> Any:
        # build the plain dataloader here, then wrap it
        dl = DataLoader(self.mnist_train, batch_size=32)
        return wrap_data_loader(
            data_loader=dl,
            max_batch_size=self.batch_size_physical,
            optimizer=self.optimizer,
        )

lsc64 commented 3 months ago

hi @awaelchli

This is by design and expected

Should we document this somewhere? Right now the docs state:

LightningDataModule

Use the train_dataloader() method to generate the training dataloader(s). Usually you just wrap the dataset you defined in setup. This is the dataloader that the Trainer fit() method uses.

LightningModule

An iterable or collection of iterables specifying training samples.

But both can be used by trainer.fit; it's just that one takes precedence over the other.

For anyone else running into this: I've solved my problem by storing the optimizer and handling the wrapping outside of Lightning:

train_dataloader = dm.train_dataloader()
model.configure_optimizers(train_dataloader)
wrapped_dl = wrap_data_loader(
    data_loader=train_dataloader, max_batch_size=4096, optimizer=model.optimizer
)
trainer.fit(
    model, train_dataloaders=wrapped_dl, val_dataloaders=dm.val_dataloader()
)
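
One caveat (not shown above): when calling the datamodule's hooks yourself, you also have to run its setup yourself before requesting the dataloader, roughly:

dm = MNISTDataModule()
dm.prepare_data()  # download once
dm.setup("fit")    # creates dm.mnist_train / dm.mnist_val
train_dataloader = dm.train_dataloader()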

You just have to make sure the optimizer doesn't get configured again in your LightningModule:

def configure_optimizers(self, data_loader=None):
    if hasattr(self, "optimizer"):
        return self.optimizer
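
For completeness, a minimal sketch of the whole guard; the plain SGD below is only a placeholder, and how data_loader is actually used stays up to your own setup:

def configure_optimizers(self, data_loader=None):
    if hasattr(self, "optimizer"):
        # the optimizer was already created outside the Trainer, reuse it
        return self.optimizer
    # first call: build and cache the optimizer so later calls return the same one
    self.optimizer = torch.optim.SGD(self.parameters(), lr=0.1)
    return self.optimizer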