Lightning-AI / pytorch-lightning

Pretrain, finetune and deploy AI models on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

`train_dataloader` not recognized in DataModule #18808

Closed: jscottcronin closed this issue 11 months ago

jscottcronin commented 11 months ago

Bug description

I'm following the most basic MNIST example with a multilayer model, and it does not run properly on my Mac M1 laptop.

from typing import Any
import lightning.pytorch as pl
from lightning.pytorch.utilities.types import STEP_OUTPUT, OptimizerLRScheduler
from torch import optim, nn
import torch.nn.functional as F
import torch
from torch.utils.data import random_split, DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
from typing import Optional

class MNISTModel(pl.LightningModule):
    def __init__(self, learning_rate=1e-3):
        super().__init__()
        self.learning_rate = learning_rate

        self.model = nn.Sequential(
            nn.Flatten(),
            nn.Linear(28 * 28, 256),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, 10),
            nn.ReLU(),
        )

    def cross_entropy_loss(self, logits, labels):
        return F.nll_loss(logits, labels)

    def forward(self, x):
        x = self.model(x)
        return F.log_softmax(x, dim=1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        # ys = torch.zeros(x.size(0), 10)
        # ys[y] = 1
        logits = self.forward(x)
        loss = self.cross_entropy_loss(logits, y)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        # ys = torch.zeros(x.size(0), 10)
        # ys[y] = 1
        logits = self.forward(x)
        loss = self.cross_entropy_loss(logits, y)
        self.log("val_loss", loss)
        return loss

    def configure_optimizers(self) -> OptimizerLRScheduler:
        optimizer = optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer

class MNISTDataModule(pl.LightningDataModule):
    def __init__(self, batch_size=32) -> None:
        super().__init__()

        self.data_dir = "./data"
        self.num_classes = 10
        self.dims = (1, 28, 28)
        self.batch_size = batch_size

        self.transform = transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
        )

        def prepare_data(self) -> None:
            MNIST(root="./data", train=True, transform=self.transform, download=True)
            MNIST(root="./data", train=False, transform=self.transform, download=True)

        def setup(self, stage: Optional[str] = None) -> None:
            if stage == "fit" or stage is None:
                full = MNIST(root=self.data_dir, train=True, transform=self.transform)
                self.train, self.val = random_split(full, [55000, 5000])

            if stage == "test" or stage is None:
                self.test = MNIST(
                    root=self.data_dir, train=False, transform=self.transform
                )

        def train_dataloader(self) -> DataLoader:
            return DataLoader(self.train, batch_size=self.batch_size)

        def val_dataloader(self) -> DataLoader:
            return DataLoader(self.val, batch_size=self.batch_size)

        def test_dataloader(self) -> DataLoader:
            return DataLoader(self.test, batch_size=self.batch_size)

if __name__ == "__main__":
    model = MNISTModel()
    dm = MNISTDataModule()
    trainer = pl.Trainer(fast_dev_run=True, max_epochs=10)
    trainer.fit(model, datamodule=dm)
    trainer.test(model, datamodule=dm)

What version are you seeing the problem on?

v2.0

How to reproduce the bug

Just run the code shown in the bug description above.

Error messages and logs

Backend MacOSX is interactive backend. Turning interactive mode on.
GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Running in `fast_dev_run` mode: will run the requested loop using 1 batch(es). Logging and checkpointing is suppressed.
/Users/scottcronin/.pyenv/versions/3.11.2/lib/python3.11/site-packages/lightning/pytorch/trainer/configuration_validator.py:71: PossibleUserWarning: You defined a `validation_step` but have no `val_dataloader`. Skipping val loop.
  rank_zero_warn(

  | Name  | Type       | Params
-------------------------------------
0 | model | Sequential | 235 K 
-------------------------------------
235 K     Trainable params
0         Non-trainable params
235 K     Total params
0.941     Total estimated model params size (MB)
Traceback (most recent call last):
  File "/Users/scottcronin/.pyenv/versions/3.11.2/lib/python3.11/runpy.py", line 198, in _run_module_as_main
    return _run_code(code, main_globals, None,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/scottcronin/.pyenv/versions/3.11.2/lib/python3.11/runpy.py", line 88, in _run_code
    exec(code, run_globals)
  File "/Users/scottcronin/.vscode/extensions/ms-python.python-2023.18.0/pythonFiles/lib/python/debugpy/adapter/../../debugpy/launcher/../../debugpy/__main__.py", line 39, in <module>
    cli.main()
  File "/Users/scottcronin/.vscode/extensions/ms-python.python-2023.18.0/pythonFiles/lib/python/debugpy/adapter/../../debugpy/launcher/../../debugpy/../debugpy/server/cli.py", line 430, in main
    run()
  File "/Users/scottcronin/.vscode/extensions/ms-python.python-2023.18.0/pythonFiles/lib/python/debugpy/adapter/../../debugpy/launcher/../../debugpy/../debugpy/server/cli.py", line 284, in run_file
    runpy.run_path(target, run_name="__main__")
  File "/Users/scottcronin/.vscode/extensions/ms-python.python-2023.18.0/pythonFiles/lib/python/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 321, in run_path
    return _run_module_code(code, init_globals, run_name,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/scottcronin/.vscode/extensions/ms-python.python-2023.18.0/pythonFiles/lib/python/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 135, in _run_module_code
    _run_code(code, mod_globals, init_globals,
  File "/Users/scottcronin/.vscode/extensions/ms-python.python-2023.18.0/pythonFiles/lib/python/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 124, in _run_code
    exec(code, run_globals)
  File "/Users/scottcronin/git/refresher/s04_algorithms/t05_dl_linear_regression/multilayer.py", line 101, in <module>
    trainer.fit(model, dm)
  File "/Users/scottcronin/.pyenv/versions/3.11.2/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py", line 532, in fit
    call._call_and_handle_interrupt(
  File "/Users/scottcronin/.pyenv/versions/3.11.2/lib/python3.11/site-packages/lightning/pytorch/trainer/call.py", line 43, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/scottcronin/.pyenv/versions/3.11.2/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py", line 571, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/Users/scottcronin/.pyenv/versions/3.11.2/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py", line 980, in _run
    results = self._run_stage()
              ^^^^^^^^^^^^^^^^^
  File "/Users/scottcronin/.pyenv/versions/3.11.2/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py", line 1023, in _run_stage
    self.fit_loop.run()
  File "/Users/scottcronin/.pyenv/versions/3.11.2/lib/python3.11/site-packages/lightning/pytorch/loops/fit_loop.py", line 194, in run
    self.setup_data()
  File "/Users/scottcronin/.pyenv/versions/3.11.2/lib/python3.11/site-packages/lightning/pytorch/loops/fit_loop.py", line 222, in setup_data
    train_dataloader = _request_dataloader(source)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/scottcronin/.pyenv/versions/3.11.2/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py", line 336, in _request_dataloader
    return data_source.dataloader()
           ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/scottcronin/.pyenv/versions/3.11.2/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py", line 303, in dataloader
    return call._call_lightning_datamodule_hook(self.instance.trainer, self.name)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/scottcronin/.pyenv/versions/3.11.2/lib/python3.11/site-packages/lightning/pytorch/trainer/call.py", line 166, in _call_lightning_datamodule_hook
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/Users/scottcronin/.pyenv/versions/3.11.2/lib/python3.11/site-packages/lightning/pytorch/core/hooks.py", line 432, in train_dataloader
    raise MisconfigurationException("`train_dataloader` must be implemented to be used with the Lightning Trainer")
lightning.fabric.utilities.exceptions.MisconfigurationException: `train_dataloader` must be implemented to be used with the Lightning Trainer

Environment

<details>
  <summary>Current environment</summary>

* CUDA:
        - GPU:               None
        - available:         False
        - version:           None
* Lightning:
        - lightning:         2.1.0
        - lightning-cloud:   0.5.39
        - lightning-utilities: 0.9.0
        - pytorch-lightning: 2.1.0
        - torch:             2.1.0
        - torch-tb-profiler: 0.4.3
        - torchinfo:         1.8.0
        - torchmetrics:      1.2.0
        - torchvision:       0.16.0
* Packages:
        - absl-py:           2.0.0
        - aiofiles:          22.1.0
        - aiohttp:           3.8.4
        - aiosignal:         1.3.1
        - aiosqlite:         0.18.0
        - altair:            4.2.2
        - annotated-types:   0.6.0
        - anyio:             3.7.1
        - appnope:           0.1.3
        - argon2-cffi:       21.3.0
        - argon2-cffi-bindings: 21.2.0
        - arrow:             1.2.3
        - astor:             0.8.1
        - asttokens:         2.2.1
        - async-timeout:     4.0.2
        - attrs:             22.2.0
        - autograd:          1.6.2
        - autograd-gamma:    0.5.0
        - babel:             2.12.1
        - backcall:          0.2.0
        - backoff:           2.2.1
        - beautifulsoup4:    4.11.2
        - black:             23.9.1
        - bleach:            6.0.0
        - blessed:           1.20.0
        - blinker:           1.6.2
        - build:             0.10.0
        - cachecontrol:      0.13.1
        - cachetools:        5.3.0
        - certifi:           2022.12.7
        - cffi:              1.15.1
        - charset-normalizer: 3.1.0
        - cleo:              2.0.1
        - click:             8.1.3
        - comm:              0.1.2
        - contourpy:         1.0.7
        - crashtest:         0.4.1
        - croniter:          1.4.1
        - cycler:            0.11.0
        - dateutils:         0.6.12
        - debugpy:           1.6.6
        - decorator:         5.1.1
        - deepdiff:          6.6.0
        - defusedxml:        0.7.1
        - distlib:           0.3.7
        - dulwich:           0.21.6
        - entrypoints:       0.4
        - executing:         1.2.0
        - fastapi:           0.103.2
        - fastjsonschema:    2.16.3
        - filelock:          3.12.3
        - fonttools:         4.39.2
        - formulaic:         0.6.6
        - fqdn:              1.5.1
        - frozenlist:        1.3.3
        - fsspec:            2023.9.2
        - future:            0.18.3
        - gitdb:             4.0.10
        - gitpython:         3.1.31
        - google-auth:       2.23.3
        - google-auth-oauthlib: 1.0.0
        - grpcio:            1.59.0
        - h11:               0.14.0
        - idna:              3.4
        - importlib-metadata: 6.6.0
        - inflection:        0.5.1
        - iniconfig:         2.0.0
        - inquirer:          3.1.3
        - installer:         0.7.0
        - interface-meta:    1.3.0
        - ipykernel:         6.21.3
        - ipython:           8.11.0
        - ipython-genutils:  0.2.0
        - ipywidgets:        8.0.4
        - islp:              0.3.21
        - isoduration:       20.11.0
        - itsdangerous:      2.1.2
        - jaraco.classes:    3.3.0
        - jedi:              0.18.2
        - jinja2:            3.1.2
        - joblib:            1.3.2
        - json5:             0.9.11
        - jsonpointer:       2.3
        - jsonschema:        4.17.3
        - jupyter-client:    8.0.3
        - jupyter-core:      5.3.0
        - jupyter-events:    0.6.3
        - jupyter-server:    2.5.0
        - jupyter-server-fileid: 0.8.0
        - jupyter-server-terminals: 0.4.4
        - jupyter-server-ydoc: 0.6.1
        - jupyter-ydoc:      0.2.3
        - jupyterlab:        3.6.1
        - jupyterlab-pygments: 0.2.2
        - jupyterlab-server: 2.20.0
        - jupyterlab-widgets: 3.0.5
        - keyring:           24.2.0
        - kiwisolver:        1.4.4
        - lifelines:         0.27.8
        - lightning:         2.1.0
        - lightning-cloud:   0.5.39
        - lightning-utilities: 0.9.0
        - lxml:              4.9.3
        - markdown:          3.5
        - markdown-it-py:    2.2.0
        - markupsafe:        2.1.2
        - matplotlib:        3.8.0
        - matplotlib-inline: 0.1.6
        - mdurl:             0.1.2
        - mistune:           2.0.5
        - more-itertools:    10.1.0
        - mpmath:            1.3.0
        - msgpack:           1.0.6
        - multidict:         6.0.4
        - mypy-extensions:   1.0.0
        - nbclassic:         0.5.3
        - nbclient:          0.7.2
        - nbconvert:         7.2.10
        - nbformat:          5.7.3
        - nest-asyncio:      1.5.6
        - networkx:          3.1
        - nose:              1.3.7
        - notebook:          6.5.3
        - notebook-shim:     0.2.2
        - numpy:             1.24.4
        - oauthlib:          3.2.2
        - openai:            0.27.8
        - openapi:           1.1.0
        - ordered-set:       4.1.0
        - packaging:         23.0
        - pandas:            1.5.3
        - pandocfilters:     1.5.0
        - parso:             0.8.3
        - pathlib:           1.0.1
        - pathspec:          0.11.2
        - patsy:             0.5.3
        - pexpect:           4.8.0
        - pickleshare:       0.7.5
        - pillow:            9.5.0
        - pip:               23.3
        - pkginfo:           1.9.6
        - platformdirs:      3.10.0
        - pluggy:            1.3.0
        - poetry:            1.6.1
        - poetry-core:       1.7.0
        - poetry-plugin-export: 1.5.0
        - progressbar2:      4.2.0
        - prometheus-client: 0.16.0
        - prompt-toolkit:    3.0.38
        - protobuf:          3.20.3
        - psutil:            5.9.4
        - ptyprocess:        0.7.0
        - pure-eval:         0.2.2
        - pyarrow:           12.0.0
        - pyasn1:            0.5.0
        - pyasn1-modules:    0.3.0
        - pycparser:         2.21
        - pydantic:          2.1.1
        - pydantic-core:     2.4.0
        - pydeck:            0.8.1b0
        - pygam:             0.9.0
        - pygments:          2.14.0
        - pyjwt:             2.8.0
        - pympler:           1.0.1
        - pyparsing:         3.0.9
        - pyproject-hooks:   1.0.0
        - pyqt6:             6.4.2
        - pyqt6-qt6:         6.4.2
        - pyqt6-sip:         13.4.1
        - pyrsistent:        0.19.3
        - pytest:            7.4.2
        - python-dateutil:   2.8.2
        - python-editor:     1.0.4
        - python-json-logger: 2.0.7
        - python-multipart:  0.0.6
        - python-utils:      3.8.1
        - pytorch-lightning: 2.1.0
        - pytz:              2022.7.1
        - pyyaml:            6.0
        - pyzmq:             25.0.1
        - qgrid:             1.3.1
        - rapidfuzz:         2.15.1
        - readchar:          4.0.5
        - requests:          2.31.0
        - requests-oauthlib: 1.3.1
        - requests-toolbelt: 1.0.0
        - rfc3339-validator: 0.1.4
        - rfc3986-validator: 0.1.1
        - rich:              13.3.5
        - rsa:               4.9
        - scikit-learn:      1.3.1
        - scipy:             1.11.2
        - seaborn:           0.12.2
        - send2trash:        1.8.0
        - setuptools:        65.5.0
        - shellingham:       1.5.3
        - six:               1.16.0
        - smmap:             5.0.0
        - sniffio:           1.3.0
        - soupsieve:         2.4
        - stack-data:        0.6.2
        - starlette:         0.27.0
        - starsessions:      1.3.0
        - statsmodels:       0.14.0
        - streamlit:         1.22.0
        - sympy:             1.12
        - tenacity:          8.2.2
        - tensorboard:       2.14.1
        - tensorboard-data-server: 0.7.1
        - terminado:         0.17.1
        - threadpoolctl:     3.2.0
        - tinycss2:          1.2.1
        - toml:              0.10.2
        - tomlkit:           0.12.1
        - toolz:             0.12.0
        - torch:             2.1.0
        - torch-tb-profiler: 0.4.3
        - torchinfo:         1.8.0
        - torchmetrics:      1.2.0
        - torchvision:       0.16.0
        - tornado:           6.2
        - tqdm:              4.65.0
        - traitlets:         5.9.0
        - trove-classifiers: 2023.9.19
        - typing-extensions: 4.6.2
        - tzdata:            2023.3
        - tzlocal:           5.0.1
        - uri-template:      1.2.0
        - urllib3:           1.26.15
        - uvicorn:           0.23.2
        - validators:        0.20.0
        - virtualenv:        20.24.5
        - wcwidth:           0.2.6
        - webcolors:         1.12
        - webencodings:      0.5.1
        - websocket-client:  1.5.1
        - websockets:        11.0.3
        - werkzeug:          3.0.0
        - widgetsnbextension: 4.0.5
        - wrapt:             1.15.0
        - xattr:             0.10.1
        - y-py:              0.5.9
        - yarl:              1.9.2
        - ypy-websocket:     0.8.2
        - zipp:              3.15.0
* System:
        - OS:                Darwin
        - architecture:
                - 64bit
                - 
        - processor:         arm
        - python:            3.11.2
        - release:           22.3.0
        - version:           Darwin Kernel Version 22.3.0: Mon Jan 30 20:38:37 PST 2023; root:xnu-8792.81.3~2/RELEASE_ARM64_T6000

</details>

More info

No response

awaelchli commented 11 months ago

@jscottcronin All your methods are indented under the `__init__` method; unindent them by one level so they are defined on the class itself. Here's the fixed code:

from typing import Any
import lightning.pytorch as pl
from lightning.pytorch.utilities.types import STEP_OUTPUT, OptimizerLRScheduler
from torch import optim, nn
import torch.nn.functional as F
import torch
from torch.utils.data import random_split, DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
from typing import Optional

class MNISTModel(pl.LightningModule):
    def __init__(self, learning_rate=1e-3):
        super().__init__()
        self.learning_rate = learning_rate

        self.model = nn.Sequential(
            nn.Flatten(),
            nn.Linear(28 * 28, 256),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, 10),
            nn.ReLU(),
        )

    def cross_entropy_loss(self, logits, labels):
        return F.nll_loss(logits, labels)

    def forward(self, x):
        x = self.model(x)
        return F.log_softmax(x, dim=1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        # ys = torch.zeros(x.size(0), 10)
        # ys[y] = 1
        logits = self.forward(x)
        loss = self.cross_entropy_loss(logits, y)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        # ys = torch.zeros(x.size(0), 10)
        # ys[y] = 1
        logits = self.forward(x)
        loss = self.cross_entropy_loss(logits, y)
        self.log("val_loss", loss)
        return loss

    def configure_optimizers(self) -> OptimizerLRScheduler:
        optimizer = optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer

class MNISTDataModule(pl.LightningDataModule):
    def __init__(self, batch_size=32) -> None:
        super().__init__()

        self.data_dir = "./data"
        self.num_classes = 10
        self.dims = (1, 28, 28)
        self.batch_size = batch_size

        self.transform = transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
        )

    def prepare_data(self) -> None:
        MNIST(root="./data", train=True, transform=self.transform, download=True)
        MNIST(root="./data", train=False, transform=self.transform, download=True)

    def setup(self, stage: Optional[str] = None) -> None:
        if stage == "fit" or stage is None:
            full = MNIST(root=self.data_dir, train=True, transform=self.transform)
            self.train, self.val = random_split(full, [55000, 5000])

        if stage == "test" or stage is None:
            self.test = MNIST(
                root=self.data_dir, train=False, transform=self.transform
            )

    def train_dataloader(self) -> DataLoader:
        return DataLoader(self.train, batch_size=self.batch_size)

    def val_dataloader(self) -> DataLoader:
        return DataLoader(self.val, batch_size=self.batch_size)

    def test_dataloader(self) -> DataLoader:
        return DataLoader(self.test, batch_size=self.batch_size)

if __name__ == "__main__":
    model = MNISTModel()
    dm = MNISTDataModule()
    trainer = pl.Trainer(fast_dev_run=True, max_epochs=10)
    trainer.fit(model, datamodule=dm)
    # trainer.test(model, datamodule=dm)
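
For context, this is standard Python behavior rather than anything Lightning-specific: a def statement nested inside `__init__` only creates a local function that is discarded when `__init__` returns, so it never becomes a method of the class. That is why the Trainer fell back to the base `train_dataloader` hook, which raises the `MisconfigurationException` you saw. A minimal sketch of the failure mode (the `Demo` class here is just an illustration):

class Demo:
    def __init__(self):
        # Local function: it exists only while __init__ runs and is
        # never attached to the class or the instance.
        def train_dataloader(self):
            return "data"

d = Demo()
print(hasattr(d, "train_dataloader"))  # False

As a quick sanity check before calling `fit`, you can verify that your datamodule actually overrides the hook (assuming `dm` and `pl` are bound as in your script):

assert type(dm).train_dataloader is not pl.LightningDataModule.train_dataloader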

I'm closing the issue because I don't see any action items. The error message was correct.

awaelchli commented 11 months ago

Please let me know if you copied this code from somewhere in the docs; if so, we need to fix it there.