Open peacekurella opened 1 month ago
I noticed that the drop-down menu does not contain 2.3.x as part of the version selection.
Hey @peacekurella can you please provide a code example based on https://github.com/Lightning-AI/pytorch-lightning/blob/master/examples/pytorch/bug_report/bug_report_model.py so we can verify it's not working?
I can do that.
import torch
from typing import Type, TypeVar
from lightning.pytorch import LightningModule
from torch.utils.data import DataLoader, Dataset
from lightning.pytorch.cli import LightningCLI
from lightning import LightningDataModule
from lightning.pytorch.callbacks import ModelCheckpoint
class MyLightningCLI(LightningCLI):
    """LightningCLI subclass used by the repro.

    Registers the vocab-size arguments, links the data module's value to the
    model's at instantiation time, and adds a ``--ckpt_path_ex`` option that
    switches class instantiation over to checkpoint loading.
    """

    def add_arguments_to_parser(self, parser):
        """Add the extra arguments and the data -> model argument link.

        Args:
            parser: The argument parser handed in by ``LightningCLI``.
        """
        source_key = "data.destinationaddressid_vocab_size"
        target_key = "model.destinationaddressid_vocab_size"

        parser.add_argument(source_key, default=10)
        parser.add_argument(target_key)
        parser.add_argument("--ckpt_path_ex", type=str, default=None)
        # The model's value is taken from the data module at instantiation.
        parser.link_arguments(source_key, target_key, apply_on="instantiate")

    def before_instantiate_classes(self) -> None:
        """Install the checkpoint-based instantiators when a checkpoint path is set."""
        # NOTE(review): indentation was lost in the pasted snippet; the
        # registration call is assumed to belong to the checkpoint branch,
        # since a run without --ckpt_path_ex has nothing to load from.
        if not self.config.ckpt_path_ex:
            return
        print("restoring from checkpoint")
        CheckpointModuleInstantiatiorCLI.before_instantiate_classes(self)
class MyDataModule(LightningDataModule):
    """Data module that records a vocab size and serves random batches."""

    def __init__(self, destinationaddressid_vocab_size: int = None):
        """Store the vocab size and print it so each run shows the value received.

        Args:
            destinationaddressid_vocab_size: Value kept on the instance;
                defaults to ``None``.
        """
        super().__init__()
        self.destinationaddressid_vocab_size = destinationaddressid_vocab_size
        print(f"The value of destinationaddressid_vocab_size in data module is {destinationaddressid_vocab_size}")

    def _random_loader(self):
        """Build a fresh loader over 64 random samples of width 32, batch size 2."""
        dataset = RandomDataset(32, 64)
        return DataLoader(dataset, batch_size=2)

    def train_dataloader(self):
        """Return the loader used for training."""
        return self._random_loader()

    def val_dataloader(self):
        """Return the loader used for validation."""
        return self._random_loader()

    def predict_dataloader(self):
        """Return the loader used for prediction."""
        return self._random_loader()
class RandomDataset(Dataset):
    """In-memory dataset of random vectors drawn with ``torch.randn``."""

    def __init__(self, size, length):
        """Create ``length`` random samples, each of width ``size``.

        Args:
            size: Number of features per sample.
            length: Number of samples in the dataset.
        """
        self.data = torch.randn(length, size)
        self.len = length

    def __len__(self):
        """Return the number of samples given at construction."""
        return self.len

    def __getitem__(self, index):
        """Return the sample stored at ``index``."""
        sample = self.data[index]
        return sample
# Type variable tying the class passed to an instantiator (``Type[ModuleType]``)
# to the type of the object it returns.
ModuleType = TypeVar("ModuleType")
class CheckpointModuleInstantiatiorCLI:
    """Instantiator that builds the model and data module from a checkpoint.

    Instead of calling the class constructors, the registered instantiator
    calls ``load_from_checkpoint`` with the path stored in
    ``cli.config.ckpt_path_ex``.
    """

    def __init__(self, cli: LightningCLI):
        """Keep a reference to the CLI whose parser and config are consulted."""
        self.cli = cli

    def class_instantiator(self, class_type: Type[ModuleType], *args, **kwargs) -> ModuleType:
        """Load ``class_type`` from the checkpoint, forwarding only non-default kwargs.

        Args:
            class_type: Class being instantiated; must be ``BoringModel`` or
                ``MyDataModule``.
            *args: Must be empty.
            **kwargs: Init arguments resolved by the parser.

        Returns:
            The object returned by ``load_from_checkpoint``.

        Raises:
            ValueError: If positional arguments are given or ``class_type`` is
                neither of the two supported classes.
        """
        if args:
            raise ValueError("Unexpected args")

        # Fall back to CPU when no CUDA device is available.
        device_map = "cpu" if not torch.cuda.is_available() else None
        parser_defaults = self.cli.parser.get_defaults()

        if class_type == BoringModel:
            target, section_defaults = BoringModel, parser_defaults.model
        elif class_type == MyDataModule:
            target, section_defaults = MyDataModule, parser_defaults.data
        else:
            raise ValueError("Unexpected class")

        # Only values that differ from the parser defaults are passed on;
        # presumably so untouched defaults do not override what the checkpoint
        # stored -- confirm against load_from_checkpoint semantics.
        overrides = {}
        for key, value in kwargs.items():
            if section_defaults.get(key) != value:
                overrides[key] = value

        return target.load_from_checkpoint(
            self.cli.config.ckpt_path_ex,
            map_location=device_map,
            **overrides,
        )

    @staticmethod
    def before_instantiate_classes(cli: LightningCLI) -> None:
        """Register the checkpoint instantiator for both supported classes.

        Args:
            cli: The CLI whose parser receives the instantiators.
        """
        loader = CheckpointModuleInstantiatiorCLI(cli).class_instantiator
        for module_class in (BoringModel, MyDataModule):
            cli.parser.add_instantiator(loader, module_class)
class BoringModel(LightningModule):
    """Minimal module with a single ``Linear(32, 2)`` layer, used for the repro."""

    def __init__(self, destinationaddressid_vocab_size):
        """Store the vocab size, save hyperparameters and print the value received.

        Args:
            destinationaddressid_vocab_size: Value kept on the instance and
                echoed to stdout.
        """
        super().__init__()
        self.destinationaddressid_vocab_size = destinationaddressid_vocab_size
        self.layer = torch.nn.Linear(32, 2)
        self.save_hyperparameters()
        print(f"The value of destinationaddressid_vocab_size in model module is {self.destinationaddressid_vocab_size}")

    def forward(self, x):
        """Apply the linear layer to ``x``."""
        output = self.layer(x)
        return output

    def training_step(self, batch, batch_idx):
        """Log the summed output as ``train_loss`` and return it under ``"loss"``."""
        total = self(batch).sum()
        self.log("train_loss", total)
        return {"loss": total}

    def validation_step(self, batch, batch_idx):
        """Log the summed output as ``valid_loss``."""
        total = self(batch).sum()
        self.log("valid_loss", total)

    def test_step(self, batch, batch_idx):
        """Log the summed output as ``test_loss``."""
        total = self(batch).sum()
        self.log("test_loss", total)

    def configure_optimizers(self):
        """Return an SGD optimizer over the layer's parameters with ``lr=0.1``."""
        params = self.layer.parameters()
        return torch.optim.SGD(params, lr=0.1)
def run(args):
    """Build the CLI without running a subcommand, then call ``trainer.fit``.

    Args:
        args: Forwarded unchanged to ``MyLightningCLI`` as ``args``.
    """
    checkpoint_callback = ModelCheckpoint(dirpath="ckpts")
    cli = MyLightningCLI(
        BoringModel,
        MyDataModule,
        args=args,
        trainer_defaults={"callbacks": [checkpoint_callback]},
        run=False,
    )

    # A falsy checkpoint path (e.g. not supplied) is normalised to None.
    resume_path = cli.config.ckpt_path_ex or None
    cli.trainer.fit(
        model=cli.model,
        datamodule=cli.datamodule,
        ckpt_path=resume_path,
    )
# Script entry point. args=None is forwarded to run(); presumably the CLI then
# reads the process command line -- confirm against LightningCLI.
if __name__ == "__main__":
    run(args=None)
Running with lightning 2.2.5
Generate checkpoints: `python bug_report.py --data.destinationaddressid_vocab_size 15 --trainer.max_epoch=1`. This prints:
The value of destinationaddressid_vocab_size in data module is 15
The value of destinationaddressid_vocab_size in model module is 15
Load the model from the checkpoint: `python bug_report.py --trainer.max_epoch=2 --ckpt_path_ex ckpts/epoch=0-step=32.ckpt`. This prints:
The value of destinationaddressid_vocab_size in data module is None
The value of destinationaddressid_vocab_size in model module is 15
Running with lightning 2.3.3
Generate checkpoints: `python bug_report.py --data.destinationaddressid_vocab_size 15 --trainer.max_epoch=1`. This prints:
The value of destinationaddressid_vocab_size in data module is 15
The value of destinationaddressid_vocab_size in model module is 15
Load the model from the checkpoint: `python bug_report.py --trainer.max_epoch=2 --ckpt_path_ex ckpts/epoch=0-step=32.ckpt`. This prints:
The value of destinationaddressid_vocab_size in data module is 10
The value of destinationaddressid_vocab_size in model module is 10
@awaelchli added the repro code and scenarios with outputs.
Ok, thanks @peacekurella. But the default value is 10, and in the second command you don't pass `--data.destinationaddressid_vocab_size 15`. When you resume training, you would certainly need to pass the same configuration. We can't expect the output to be 15 in the second example if `data.destinationaddressid_vocab_size` is not passed.
@awaelchli The way I understand it, `save_hyperparameters()` is not storing the values of parameters that have been linked previously. This was not the case in lightning 2.2.5. This is a problem when restoring from ckpt files for inference: typically we try to get all the required hyperparameters (HP) for inference from the ckpt file itself.
Bug description
When using `parser.link_arguments` to link fields a and b with `apply_on="instantiate"`, field b is not populated when it is accessed later. This was not a problem in lightning 2.2.5, which we are currently using. However, upgrading to 2.3.x causes field b to not be populated.
What version are you seeing the problem on?
2.3.3
How to reproduce the bug
https://github.com/Lightning-AI/pytorch-lightning/issues/20147#issuecomment-2266215234
Error messages and logs
-
Environment
Current environment
* CUDA: - GPU: None - available: False - version: 12.1 * Lightning: - lightning: 2.2.5 - lightning-utilities: 0.11.2 - pytorch-lightning: 2.3.0 - torch: 2.3.1 - torchmetrics: 1.4.0.post0 * Packages: - aiobotocore: 2.7.0 - aiohttp: 3.9.5 - aioitertools: 0.7.1 - aiosignal: 1.2.0 - alabaster: 0.7.16 - altair: 5.0.1 - anyio: 4.2.0 - appdirs: 1.4.4 - argon2-cffi: 21.3.0 - argon2-cffi-bindings: 21.2.0 - arrow: 1.2.3 - astroid: 2.14.2 - astropy: 6.1.0 - astropy-iers-data: 0.2024.6.3.0.31.14 - asttokens: 2.0.5 - async-lru: 2.0.4 - async-timeout: 4.0.3 - atomicwrites: 1.4.0 - attrs: 23.1.0 - automat: 20.2.0 - autopep8: 2.0.4 - babel: 2.11.0 - bcrypt: 3.2.0 - beautifulsoup4: 4.12.3 - binaryornot: 0.4.4 - black: 24.4.2 - bleach: 4.1.0 - blinker: 1.6.2 - bokeh: 3.4.1 - boto3: 1.34.131 - botocore: 1.34.131 - bottleneck: 1.3.7 - brotli: 1.0.9 - cachetools: 5.3.3 - cattrs: 23.2.3 - certifi: 2024.6.2 - cffi: 1.16.0 - chardet: 4.0.0 - charset-normalizer: 3.3.2 - click: 8.1.7 - cloudpickle: 2.2.1 - colorama: 0.4.6 - colorcet: 3.1.0 - comm: 0.2.1 - constantly: 23.10.4 - contourpy: 1.2.0 - cookiecutter: 2.6.0 - cryptography: 42.0.5 - cssselect: 1.2.0 - cycler: 0.11.0 - cytoolz: 0.12.2 - dask: 2024.5.0 - dask-expr: 1.1.0 - datasets: 2.14.6 - datashader: 0.16.2 - debugpy: 1.6.7 - decorator: 5.1.1 - defusedxml: 0.7.1 - diff-match-patch: 20200713 - dill: 0.3.7 - distributed: 2024.5.0 - docker: 7.1.0 - docstring-parser: 0.16 - docstring-to-markdown: 0.11 - docutils: 0.18.1 - entrypoints: 0.4 - et-xmlfile: 1.1.0 - exceptiongroup: 1.2.0 - executing: 0.8.3 - fastjsonschema: 2.16.2 - filelock: 3.13.1 - flake8: 7.0.0 - flask: 3.0.3 - fonttools: 4.51.0 - frozenlist: 1.4.0 - fsspec: 2023.10.0 - gensim: 4.3.2 - gitdb: 4.0.7 - gitpython: 3.1.37 - gmpy2: 2.1.2 - google-pasta: 0.2.0 - greenlet: 3.0.1 - h5py: 3.11.0 - heapdict: 1.0.1 - holoviews: 1.19.0 - huggingface-hub: 0.23.4 - hvplot: 0.10.0 - hyperlink: 21.0.0 - idna: 3.7 - imagecodecs: 2023.1.23 - imageio: 2.33.1 - imagesize: 1.4.1 - 
imbalanced-learn: 0.12.3 - importlib-metadata: 6.11.0 - importlib-resources: 6.4.0 - incremental: 22.10.0 - inflection: 0.5.1 - iniconfig: 1.1.1 - intake: 0.7.0 - intervaltree: 3.1.0 - ipykernel: 6.28.0 - ipython: 8.25.0 - ipython-genutils: 0.2.0 - ipywidgets: 7.6.5 - isort: 5.13.2 - itemadapter: 0.3.0 - itemloaders: 1.1.0 - itsdangerous: 2.2.0 - jaraco.classes: 3.2.1 - jedi: 0.18.1 - jeepney: 0.7.1 - jellyfish: 1.0.1 - jinja2: 3.1.4 - jmespath: 1.0.1 - joblib: 1.4.2 - json5: 0.9.6 - jsonargparse: 4.30.0 - jsonschema: 4.19.2 - jsonschema-specifications: 2023.7.1 - jupyter: 1.0.0 - jupyter-client: 8.6.0 - jupyter-console: 6.6.3 - jupyter-core: 5.5.0 - jupyter-events: 0.10.0 - jupyter-lsp: 2.2.0 - jupyter-server: 2.10.0 - jupyter-server-terminals: 0.4.4 - jupyterlab: 4.0.11 - jupyterlab-pygments: 0.1.2 - jupyterlab-server: 2.25.1 - jupyterlab-widgets: 3.0.10 - keyring: 24.3.1 - kiwisolver: 1.4.4 - klon: 2.3.0 - lazy-loader: 0.4 - lazy-object-proxy: 1.10.0 - lckr-jupyterlab-variableinspector: 3.1.0 - lightning: 2.2.5 - lightning-utilities: 0.11.2 - linkify-it-py: 2.0.0 - llvmlite: 0.42.0 - lmdb: 1.4.1 - locket: 1.0.0 - lsprotocol: 2023.0.1 - lxml: 4.9.4 - lxml-stubs: 0.1.1 - lz4: 4.3.2 - markdown: 3.4.1 - markdown-it-py: 2.2.0 - markupsafe: 2.1.3 - matplotlib: 3.8.4 - matplotlib-inline: 0.1.6 - mccabe: 0.7.0 - mdit-py-plugins: 0.3.0 - mdurl: 0.1.0 - mistune: 2.0.4 - mkl-fft: 1.3.8 - mkl-random: 1.2.4 - mkl-service: 2.4.0 - more-itertools: 10.1.0 - mpmath: 1.3.0 - msgpack: 1.0.3 - multidict: 6.0.4 - multipledispatch: 0.6.0 - multiprocess: 0.70.15 - mypy: 1.10.0 - mypy-extensions: 1.0.0 - nbclient: 0.8.0 - nbconvert: 7.10.0 - nbformat: 5.9.2 - nest-asyncio: 1.6.0 - networkx: 3.2.1 - nltk: 3.8.1 - notebook: 7.0.8 - notebook-shim: 0.2.3 - numba: 0.59.1 - numexpr: 2.8.7 - numpy: 1.26.4 - numpydoc: 1.7.0 - nvidia-cublas-cu12: 12.1.3.1 - nvidia-cuda-cupti-cu12: 12.1.105 - nvidia-cuda-nvrtc-cu12: 12.1.105 - nvidia-cuda-runtime-cu12: 12.1.105 - nvidia-cudnn-cu12: 8.9.2.26 - 
nvidia-cufft-cu12: 11.0.2.54 - nvidia-curand-cu12: 10.3.2.106 - nvidia-cusolver-cu12: 11.4.5.107 - nvidia-cusparse-cu12: 12.1.0.106 - nvidia-nccl-cu12: 2.20.5 - nvidia-nvjitlink-cu12: 12.5.40 - nvidia-nvtx-cu12: 12.1.105 - openpyxl: 3.1.2 - overrides: 7.4.0 - packaging: 23.2 - pandas: 2.2.2 - pandocfilters: 1.5.0 - panel: 1.4.4 - param: 2.1.0 - parsel: 1.8.1 - parso: 0.8.3 - partd: 1.4.1 - pathos: 0.3.1 - pathspec: 0.10.3 - patsy: 0.5.6 - pexpect: 4.8.0 - pickleshare: 0.7.5 - pillow: 10.3.0 - pip: 24.0 - platformdirs: 3.10.0 - plotly: 5.22.0 - pluggy: 1.5.0 - ply: 3.11 - pox: 0.3.4 - ppft: 1.7.6.8 - prometheus-client: 0.14.1 - prompt-toolkit: 3.0.43 - protego: 0.1.16 - protobuf: 3.20.3 - psutil: 5.9.0 - ptyprocess: 0.7.0 - pure-eval: 0.2.2 - py-cpuinfo: 9.0.0 - pyarrow: 14.0.2 - pyasn1: 0.4.8 - pyasn1-modules: 0.2.8 - pycodestyle: 2.11.1 - pycparser: 2.21 - pyct: 0.5.0 - pycurl: 7.45.2 - pydeck: 0.8.0 - pydispatcher: 2.0.5 - pydocstyle: 6.3.0 - pyerfa: 2.0.1.4 - pyflakes: 3.2.0 - pygls: 1.3.1 - pygments: 2.15.1 - pylint: 2.16.2 - pylint-venv: 3.0.3 - pyls-spyder: 0.4.0 - pyodbc: 5.0.1 - pyopenssl: 24.0.0 - pyparsing: 3.0.9 - pyproj: 3.6.1 - pyqt5: 5.15.10 - pyqt5-sip: 12.13.0 - pyqtwebengine: 5.15.6 - pysocks: 1.7.1 - pytest: 8.2.2 - python-dateutil: 2.9.0.post0 - python-json-logger: 2.0.7 - python-lsp-black: 2.0.0 - python-lsp-jsonrpc: 1.1.2 - python-lsp-server: 1.10.0 - python-slugify: 5.0.2 - python-snappy: 0.6.1 - pytoolconfig: 1.2.6 - pytorch-lightning: 2.3.0 - pytz: 2024.1 - pyviz-comms: 3.0.2 - pywavelets: 1.5.0 - pyxdg: 0.27 - pyyaml: 6.0.1 - pyzmq: 25.1.2 - qdarkstyle: 3.2.3 - qstylizer: 0.2.2 - qtawesome: 1.2.2 - qtconsole: 5.5.1 - qtpy: 2.4.1 - queuelib: 1.6.2 - referencing: 0.30.2 - regex: 2023.10.3 - requests: 2.32.3 - requests-file: 1.5.1 - rfc3339-validator: 0.1.4 - rfc3986-validator: 0.1.1 - rich: 13.3.5 - rope: 1.12.0 - rpds-py: 0.10.6 - rtree: 1.0.1 - ruff: 0.4.9 - ruff-lsp: 0.0.53 - s3fs: 2023.10.0 - s3transfer: 0.10.1 - sagemaker: 2.224.0 - 
schema: 0.7.7 - scikit-image: 0.23.2 - scikit-learn: 1.4.2 - scipy: 1.11.4 - scrapy: 2.11.1 - seaborn: 0.13.2 - secretstorage: 3.3.1 - send2trash: 1.8.2 - service-identity: 18.1.0 - setuptools: 69.5.1 - sip: 6.7.12 - six: 1.16.0 - smart-open: 5.2.1 - smdebug-rulesconfig: 1.0.1 - smmap: 4.0.0 - sniffio: 1.3.0 - snowballstemmer: 2.2.0 - sortedcontainers: 2.4.0 - soupsieve: 2.5 - sphinx: 7.3.7 - sphinxcontrib-applehelp: 1.0.2 - sphinxcontrib-devhelp: 1.0.2 - sphinxcontrib-htmlhelp: 2.0.0 - sphinxcontrib-jsmath: 1.0.1 - sphinxcontrib-qthelp: 1.0.3 - sphinxcontrib-serializinghtml: 1.1.10 - spyder: 5.5.1 - spyder-kernels: 2.5.0 - sqlalchemy: 2.0.30 - stack-data: 0.2.0 - statsmodels: 0.14.2 - streamlit: 1.32.0 - sympy: 1.12 - tables: 3.9.2 - tabulate: 0.9.0 - tblib: 1.7.0 - tenacity: 8.2.2 - tensorboardx: 2.6.2.2 - terminado: 0.17.1 - text-unidecode: 1.3 - textdistance: 4.2.1 - threadpoolctl: 2.2.0 - three-merge: 0.1.1 - tifffile: 2023.4.12 - tinycss2: 1.2.1 - tldextract: 3.2.0 - toml: 0.10.2 - tomli: 2.0.1 - tomlkit: 0.11.1 - toolz: 0.12.0 - torch: 2.3.1 - torchmetrics: 1.4.0.post0 - tornado: 6.4.1 - tqdm: 4.66.4 - traitlets: 5.14.3 - triton: 2.3.1 - twisted: 23.10.0 - typeshed-client: 2.5.1 - typing-extensions: 4.11.0 - tzdata: 2023.3 - uc-micro-py: 1.0.1 - ujson: 5.10.0 - unicodedata2: 15.1.0 - unidecode: 1.2.0 - urllib3: 2.0.7 - w3lib: 2.1.2 - watchdog: 4.0.1 - wcwidth: 0.2.5 - webencodings: 0.5.1 - websocket-client: 1.8.0 - werkzeug: 3.0.3 - whatthepatch: 1.0.2 - wheel: 0.43.0 - widgetsnbextension: 3.5.2 - wrapt: 1.14.1 - wurlitzer: 3.0.2 - xarray: 2023.6.0 - xxhash: 3.4.1 - xyzservices: 2022.9.0 - yapf: 0.40.2 - yarl: 1.9.3 - zict: 3.0.0 - zipp: 3.17.0 - zope.interface: 5.4.0 * System: - OS: Linux - architecture: - 64bit - ELF - processor: x86_64 - python: 3.10.13 - release: 5.10.220-188.869.amzn2int.x86_64 - version: #1 SMP Wed Jul 17 14:39:49 UTC 2024More info
No response
cc @carmocca @mauvilsa