Status: Open. lsc64 opened this issue 3 months ago.
Hi @lsc64,
As I explained on the forum, if you pass a LightningDataModule to the Trainer, the dataloader methods on the LightningModule (such as train_dataloader()) won't be called. This is by design and expected. If your data loading depends on model-specific attributes, a LightningDataModule is not a good choice; it should only be used when you can decouple the data definitions from your model.
In your case, implement the data hooks on the LightningModule instead:
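from typing import Any

import torch
from lightning.pytorch import LightningModule
from opacus.utils.batch_memory_manager import wrap_data_loader
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import MNIST


class BoringModel(LightningModule):
    def __init__(self, data_dir: str = "./data", batch_size: int = 256, batch_size_physical: int = 64):
        super().__init__()
        self.layer = torch.nn.Linear(28, 2)
        # attributes used by the data hooks below
        self.data_dir = data_dir
        self.transform = transforms.ToTensor()
        self.batch_size = batch_size
        self.batch_size_physical = batch_size_physical

    def forward(self, batch):
        x, y = batch
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("valid_loss", loss)

    def test_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("test_loss", loss)

    def configure_optimizers(self):
        # keep a reference so train_dataloader() can pass it to wrap_data_loader;
        # with Opacus this has to be a DPOptimizer (e.g. from PrivacyEngine.make_private)
        self.optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
        return self.optimizer

    def prepare_data(self):
        # download
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage: str):
        # Assign train/val datasets for use in dataloaders
        if stage == "fit":
            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
            self.mnist_train, self.mnist_val = random_split(
                mnist_full, [55000, 5000], generator=torch.Generator().manual_seed(42)
            )
        # Assign test dataset for use in dataloader(s)
        if stage == "test":
            self.mnist_test = MNIST(
                self.data_dir, train=False, transform=self.transform
            )
        if stage == "predict":
            self.mnist_predict = MNIST(
                self.data_dir, train=False, transform=self.transform
            )

    def train_dataloader(self) -> Any:
        dl = DataLoader(self.mnist_train, batch_size=self.batch_size)
        # wrap here, because wrapping needs the optimizer created in configure_optimizers()
        return wrap_data_loader(
            data_loader=dl,
            max_batch_size=self.batch_size_physical,
            optimizer=self.optimizer,
        )

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=self.batch_size)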
Hi @awaelchli,
> This is by design and expected
Should we document this somewhere? Right now the docs state:
- LightningDataModule: "Use the train_dataloader() method to generate the training dataloader(s). Usually you just wrap the dataset you defined in setup. This is the dataloader that the Trainer fit() method uses."
- LightningModule: "An iterable or collection of iterables specifying training samples."
But both can be used by trainer.fit(); it's just that when a datamodule is passed, it replaces the LightningModule as the source of the dataloaders.
I've worked around the problem by storing the optimizer on the model and wrapping the dataloader outside of Lightning. For anyone else hitting this:
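from opacus.utils.batch_memory_manager import wrap_data_loader

# dm, model and trainer are your existing datamodule, LightningModule and Trainer
dm.prepare_data()
dm.setup("fit")  # Lightning won't do this for you, since dm is not passed to fit()
train_dataloader = dm.train_dataloader()
model.configure_optimizers(train_dataloader)  # creates and stores model.optimizer
wrapped_dl = wrap_data_loader(
    data_loader=train_dataloader, max_batch_size=4096, optimizer=model.optimizer
)
trainer.fit(
    model, train_dataloaders=wrapped_dl, val_dataloaders=dm.val_dataloader()
)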
You just have to make sure the optimizer doesn't get configured a second time when Lightning calls configure_optimizers() during trainer.fit():
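def configure_optimizers(self, data_loader=None):
    # Lightning calls this again during trainer.fit(); reuse the optimizer from the manual call
    if hasattr(self, "optimizer"):
        return self.optimizer
    # ...otherwise build the optimizer (using data_loader), store it as self.optimizer and return it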
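The guard above can be filled out as a minimal sketch like the following. It uses a plain torch.nn.Module as a hypothetical stand-in for the LightningModule (CachedOptimizerModel is an illustrative name) and plain SGD in place of an Opacus-wrapped optimizer; it only shows that a second call returns the cached optimizer instead of building a new one.

```python
import torch


class CachedOptimizerModel(torch.nn.Module):
    """Hypothetical stand-in for the LightningModule; only the caching logic matters."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(28, 2)

    def configure_optimizers(self, data_loader=None):
        # Second call (e.g. from Trainer.fit): reuse the optimizer built by the first call
        if getattr(self, "optimizer", None) is not None:
            return self.optimizer
        # First call: build and cache it. With Opacus, data_loader would be used here
        # to make the optimizer private; that step is omitted in this sketch.
        self.optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
        return self.optimizer


model = CachedOptimizerModel()
first = model.configure_optimizers(data_loader=None)  # the manual call before fit()
second = model.configure_optimizers()  # what Lightning's own call would get
print(first is second)  # True
```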
Bug description
The train_dataloader() hook of the LightningModule is not being called from Trainer.fit(). I need to put code there that changes the dataloader and requires access to the optimizers, as follows:
What version are you seeing the problem on?
v2.1, v2.2
How to reproduce the bug
Run the following code. It should print "Hello from train_dataloader in the LightningModule" if the function is being called.
Error messages and logs
Environment
Current environment
* CUDA:
  - GPU:
    - 8 × NVIDIA A100-SXM4-40GB
  - available: True
  - version: 12.1
* Lightning:
  - lightning: 2.2.1
  - lightning-bolts: 0.7.0
  - lightning-utilities: 0.11.0
  - pytorch-lightning: 2.2.1
  - torch: 2.2.1
  - torchaudio: 2.2.1
  - torchmetrics: 1.3.2
  - torchvision: 0.17.1
* Packages:
  - absl-py: 2.1.0
  - aiohttp: 3.9.3
  - aiohttp-cors: 0.7.0
  - aiosignal: 1.3.1
  - alembic: 1.13.1
  - aniso8601: 9.0.1
  - annotated-types: 0.6.0
  - antlr4-python3-runtime: 4.9.3
  - anyio: 4.3.0
  - archspec: 0.2.2
  - argon2-cffi: 23.1.0
  - argon2-cffi-bindings: 21.2.0
  - arrow: 1.3.0
  - asciitree: 0.3.3
  - asttokens: 2.4.1
  - async-lru: 2.0.4
  - attrs: 23.2.0
  - autodp: 0.2.3.1
  - babel: 2.14.0
  - bcrypt: 4.1.2
  - beautifulsoup4: 4.12.3
  - bleach: 6.1.0
  - blessed: 1.19.1
  - blinker: 1.7.0
  - boltons: 23.1.1
  - brotli: 1.1.0
  - cached-property: 1.5.2
  - cachetools: 5.3.3
  - certifi: 2024.2.2
  - cffi: 1.16.0
  - cfgv: 3.4.0
  - chardet: 5.2.0
  - charset-normalizer: 3.3.2
  - chex: 0.1.85
  - click: 8.1.7
  - cloudpickle: 3.0.0
  - colorama: 0.4.6
  - colorful: 0.5.6
  - colorlog: 6.8.2
  - comm: 0.2.2
  - conda: 24.1.2
  - conda-build: 24.1.2
  - conda-index: 0.4.0
  - conda-libmamba-solver: 23.12.0
  - conda-package-handling: 2.2.0
  - conda-package-streaming: 0.9.0
  - contourpy: 1.2.0
  - cryptography: 42.0.5
  - cycler: 0.12.1
  - dask: 2024.2.1
  - debugpy: 1.8.1
  - decorator: 5.1.1
  - defusedxml: 0.7.1
  - diffprivlib: 0.6.4
  - distlib: 0.3.8
  - distro: 1.8.0
  - dm-tree: 0.1.8
  - docker: 7.0.0
  - dp-learning-ff: 0.0.9.dev23+g5b7d4b5.d20240319
  - entrypoints: 0.4
  - equinox: 0.11.3
  - etils: 1.7.0
  - exceptiongroup: 1.2.0
  - executing: 2.0.1
  - fasteners: 0.17.3
  - fastjsonschema: 2.19.1
  - filelock: 3.13.1
  - flask: 3.0.2
  - flax: 0.8.2
  - fonttools: 4.49.0
  - fqdn: 1.5.1
  - frozenlist: 1.4.1
  - fsspec: 2024.3.1
  - gast: 0.5.4
  - gitdb: 4.0.11
  - gitpython: 3.1.42
  - gmpy2: 2.1.2
  - google-api-core: 2.17.1
  - google-auth: 2.28.1
  - google-vizier: 0.1.15
  - googleapis-common-protos: 1.62.0
  - gpustat: 1.1.1
  - graphene: 3.3
  - graphql-core: 3.2.3
  - graphql-relay: 3.2.0
  - greenlet: 3.0.3
  - grpcio: 1.62.1
  - grpcio-tools: 1.62.1
  - gunicorn: 21.2.0
  - h11: 0.14.0
  - h5py: 3.10.0
  - httpcore: 1.0.4
  - httpx: 0.27.0
  - huggingface-hub: 0.21.4
  - hydra-core: 1.3.2
  - identify: 2.5.35
  - idna: 3.6
  - importlib-metadata: 7.0.1
  - importlib-resources: 6.1.2
  - ipykernel: 6.29.3
  - ipython: 8.22.2
  - ipywidgets: 8.1.2
  - isoduration: 20.11.0
  - itsdangerous: 2.1.2
  - jax: 0.4.25
  - jaxlib: 0.4.25
  - jaxopt: 0.8.3
  - jaxtyping: 0.2.28
  - jedi: 0.19.1
  - jinja2: 3.1.3
  - joblib: 1.3.2
  - json5: 0.9.24
  - jsonpatch: 1.33
  - jsonpointer: 2.4
  - jsonschema: 4.21.1
  - jsonschema-specifications: 2023.12.1
  - jupyter-client: 8.6.0
  - jupyter-core: 5.7.1
  - jupyter-events: 0.9.0
  - jupyter-lsp: 2.2.4
  - jupyter-server: 2.13.0
  - jupyter-server-mathjax: 0.2.6
  - jupyter-server-terminals: 0.5.2
  - jupyterlab: 4.1.5
  - jupyterlab-git: 0.50.0
  - jupyterlab-pygments: 0.3.0
  - jupyterlab-server: 2.25.4
  - jupyterlab-widgets: 3.0.10
  - kiwisolver: 1.4.5
  - libarchive-c: 5.0
  - libmambapy: 1.5.7
  - lightning: 2.2.1
  - lightning-bolts: 0.7.0
  - lightning-utilities: 0.11.0
  - locket: 1.0.0
  - mako: 1.3.2
  - mamba: 1.5.7
  - markdown: 3.5.2
  - markdown-it-py: 3.0.0
  - markupsafe: 2.1.5
  - matplotlib: 3.8.3
  - matplotlib-inline: 0.1.6
  - mdurl: 0.1.2
  - memory-tempfile: 2.2.3
  - menuinst: 2.0.2
  - mistune: 3.0.2
  - ml-dtypes: 0.3.2
  - mlflow: 2.11.0
  - mlflow-skinny: 2.11.0
  - more-itertools: 10.2.0
  - mpmath: 1.3.0
  - msgpack: 1.0.7
  - multidict: 6.0.5
  - munkres: 1.1.4
  - nbclient: 0.8.0
  - nbconvert: 7.16.2
  - nbdime: 4.0.1
  - nbformat: 5.9.2
  - nest-asyncio: 1.6.0
  - networkx: 3.2.1
  - nodeenv: 1.8.0
  - notebook-shim: 0.2.4
  - numcodecs: 0.12.1
  - numpy: 1.26.4
  - nvidia-cublas-cu12: 12.1.3.1
  - nvidia-cuda-cupti-cu12: 12.1.105
  - nvidia-cuda-nvrtc-cu12: 12.1.105
  - nvidia-cuda-runtime-cu12: 12.1.105
  - nvidia-cudnn-cu12: 8.9.2.26
  - nvidia-cufft-cu12: 11.0.2.54
  - nvidia-curand-cu12: 10.3.2.106
  - nvidia-cusolver-cu12: 11.4.5.107
  - nvidia-cusparse-cu12: 12.1.0.106
  - nvidia-ml-py: 12.535.133
  - nvidia-nccl-cu12: 2.19.3
  - nvidia-nvjitlink-cu12: 12.4.99
  - nvidia-nvtx-cu12: 12.1.105
  - omegaconf: 2.3.0
  - opacus: 1.4.1
  - opencensus: 0.11.4
  - opencensus-context: 0.1.3
  - opt-einsum: 3.3.0
  - optax: 0.2.1
  - optuna: 3.5.0
  - orbax-checkpoint: 0.5.6
  - overrides: 7.7.0
  - packaging: 24.0
  - pandas: 2.2.1
  - pandocfilters: 1.5.0
  - paramiko: 3.4.0
  - parso: 0.8.3
  - partd: 1.4.1
  - pexpect: 4.9.0
  - pickleshare: 0.7.5
  - pillow: 9.4.0
  - pip: 23.3.2
  - pkginfo: 1.10.0
  - pkgutil-resolve-name: 1.3.10
  - platformdirs: 4.1.0
  - pluggy: 1.3.0
  - portpicker: 1.6.0
  - pre-commit: 3.6.2
  - prometheus-client: 0.20.0
  - prometheus-flask-exporter: 0.23.0
  - prompt-toolkit: 3.0.42
  - protobuf: 4.24.4
  - psutil: 5.9.8
  - psycopg2-binary: 2.9.9
  - ptyprocess: 0.7.0
  - pure-eval: 0.2.2
  - py-spy: 0.3.14
  - pyarrow: 15.0.0
  - pyasn1: 0.5.1
  - pyasn1-modules: 0.3.0
  - pycosat: 0.6.6
  - pycparser: 2.21
  - pydantic: 2.6.3
  - pydantic-core: 2.16.3
  - pygments: 2.17.2
  - pynacl: 1.5.0
  - pyparsing: 3.1.2
  - pysocks: 1.7.1
  - python-dateutil: 2.9.0
  - python-dp: 1.1.4
  - python-json-logger: 2.0.7
  - pytorch-lightning: 2.2.1
  - pytz: 2024.1
  - pyyaml: 6.0.1
  - pyzmq: 25.1.2
  - querystring-parser: 1.2.4
  - ray: 2.9.3
  - referencing: 0.33.0
  - regex: 2023.12.25
  - requests: 2.31.0
  - rfc3339-validator: 0.1.4
  - rfc3986-validator: 0.1.1
  - rich: 13.7.1
  - rpds-py: 0.18.0
  - rsa: 4.9
  - ruamel.yaml: 0.18.6
  - ruamel.yaml.clib: 0.2.8
  - ruff: 0.3.3
  - safetensors: 0.4.2
  - scikit-learn: 1.4.1.post1
  - scipy: 1.12.0
  - seaborn: 0.13.2
  - send2trash: 1.8.2
  - setuptools: 68.2.2
  - six: 1.16.0
  - skorch: 0.15.0
  - smart-open: 7.0.1
  - smmap: 5.0.0
  - sniffio: 1.3.1
  - soupsieve: 2.5
  - sqlalchemy: 2.0.28
  - sqlparse: 0.4.4
  - stack-data: 0.6.2
  - sympy: 1.12
  - tabulate: 0.9.0
  - tensorboard: 2.16.2
  - tensorboard-data-server: 0.7.0
  - tensorstore: 0.1.56
  - terminado: 0.18.0
  - tfp-nightly: 0.25.0.dev20240318
  - threadpoolctl: 3.3.0
  - timm: 0.9.16
  - tinycss2: 1.2.1
  - tokenizers: 0.15.2
  - toolz: 0.12.1
  - torch: 2.2.1
  - torchaudio: 2.2.1
  - torchmetrics: 1.3.2
  - torchvision: 0.17.1
  - tornado: 6.4
  - tqdm: 4.66.2
  - traitlets: 5.14.1
  - transformers: 4.38.2
  - triton: 2.2.0
  - truststore: 0.8.0
  - typeguard: 2.13.3
  - types-python-dateutil: 2.8.19.20240106
  - typing-extensions: 4.10.0
  - typing-utils: 0.1.0
  - tzdata: 2024.1
  - uri-template: 1.3.0
  - urllib3: 2.1.0
  - uv: 0.1.22
  - virtualenv: 20.25.1
  - vit-proto: 0.0.0
  - wcwidth: 0.2.13
  - webcolors: 1.13
  - webencodings: 0.5.1
  - websocket-client: 1.7.0
  - werkzeug: 3.0.1
  - wheel: 0.42.0
  - widgetsnbextension: 4.0.10
  - wrapt: 1.16.0
  - yarl: 1.9.4
  - zarr: 2.17.1
  - zipp: 3.17.0
  - zstandard: 0.22.0
* System:
  - OS: Linux
  - architecture: 64bit
  - processor: x86_64
  - python: 3.11.8
  - release: 5.4.0-173-generic
  - version: #191-Ubuntu SMP Fri Feb 2 13:55:07 UTC 2024
More info
No response
cc @carmocca @awaelchli @borda