mlcommons / training

Reference implementations of MLPerf™ training benchmarks
https://mlcommons.org/en/groups/training
Apache License 2.0

pytorch_lightning no longer uses add_argparse_args #772

Open mahmoodn opened 1 month ago

mahmoodn commented 1 month ago

The stable_diffusion training command fails with an error about add_argparse_args:

$ python main.py -m train --ckpt checkpoints/sd/512-base-ema.ckpt --logdir . -b configs/train_01x08x08.yaml
:::MLLOG {"namespace": "", "time_ms": 1729239360007, "event_type": "POINT_IN_TIME", "key": "submission_benchmark", "value": "stable_diffusion", "metadata": {"file": "mlperf_logging_utils.py", "lineno": 72}}
:::MLLOG {"namespace": "", "time_ms": 1729239361128, "event_type": "POINT_IN_TIME", "key": "submission_division", "value": "closed", "metadata": {"file": "mlperf_logging_utils.py", "lineno": 73}}
:::MLLOG {"namespace": "", "time_ms": 1729239361128, "event_type": "POINT_IN_TIME", "key": "submission_org", "value": "reference_implementation", "metadata": {"file": "mlperf_logging_utils.py", "lineno": 74}}
:::MLLOG {"namespace": "", "time_ms": 1729239361128, "event_type": "POINT_IN_TIME", "key": "submission_platform", "value": "DGX-A100", "metadata": {"file": "mlperf_logging_utils.py", "lineno": 75}}
:::MLLOG {"namespace": "", "time_ms": 1729239361129, "event_type": "POINT_IN_TIME", "key": "submission_poc_name", "value": "Ahmad Kiswani", "metadata": {"file": "mlperf_logging_utils.py", "lineno": 76}}
:::MLLOG {"namespace": "", "time_ms": 1729239361129, "event_type": "POINT_IN_TIME", "key": "submission_poc_email", "value": "akiswani@nvidia.com", "metadata": {"file": "mlperf_logging_utils.py", "lineno": 77}}
:::MLLOG {"namespace": "", "time_ms": 1729239361129, "event_type": "POINT_IN_TIME", "key": "submission_status", "value": "onprem", "metadata": {"file": "mlperf_logging_utils.py", "lineno": 78}}
:::MLLOG {"namespace": "", "time_ms": 1729239361129, "event_type": "INTERVAL_START", "key": "init_start", "value": null, "metadata": {"file": "main.py", "lineno": 383}}
Traceback (most recent call last):
  File "/scratch/mahmood/training/stable_diffusion/main.py", line 396, in <module>
    parser = Trainer.add_argparse_args(parser)
             ^^^^^^^^^^^^^^^^^^^^^^^^^
AttributeError: type object 'Trainer' has no attribute 'add_argparse_args'

Looking at the code, I see that Trainer can be imported from two possible packages:

try:
    from lightning.pytorch import seed_everything
    from lightning.pytorch.callbacks import Callback
    from lightning.pytorch.trainer import Trainer
    from lightning.pytorch.utilities import rank_zero_info
    LIGHTNING_PACK_NAME = "lightning.pytorch."
except:
    from pytorch_lightning import seed_everything
    from pytorch_lightning.callbacks import Callback
    from pytorch_lightning.trainer import Trainer
    from pytorch_lightning.utilities import rank_zero_info
    LIGHTNING_PACK_NAME = "pytorch_lightning."

So Trainer comes from either lightning.pytorch or pytorch_lightning. I have the following packages installed:

$ conda list | grep lightning
lightning                 2.4.0              pyhd8ed1ab_0    conda-forge
lightning-utilities       0.11.8             pyhd8ed1ab_0    conda-forge
pytorch-lightning         2.4.0              pyhd8ed1ab_0    conda-forge
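
The same information can be read from inside the Python environment that actually runs main.py, which rules out a mismatch between the conda listing and the interpreter in use. This is a minimal sketch using only the standard library; the helper name `installed_versions` is made up for illustration, and the distribution names are the ones from the listing above.

```python
from importlib import metadata


def installed_versions(names=("lightning", "pytorch-lightning", "lightning-utilities")):
    """Return {distribution name: version string, or None if it is not installed}."""
    found = {}
    for name in names:
        try:
            found[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            # The distribution is not visible to this interpreter.
            found[name] = None
    return found


# Print whatever this interpreter can see; the values depend on the environment.
print(installed_versions())
```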

When I print LIGHTNING_PACK_NAME, I see that lightning.pytorch is the one being used, and it seems that add_argparse_args has been removed in newer versions. Any idea how to fix this without messing up the installed packages?
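
One possible workaround that leaves the installed packages alone is to build the command-line options from the constructor signature instead of calling add_argparse_args. The sketch below is not the repository's code or a Lightning API. The helpers `add_init_args` and `init_kwargs` and the class `FakeTrainer` are hypothetical, with `FakeTrainer` standing in for Trainer so the example runs without Lightning. It assumes the real `Trainer.__init__` can be introspected with `inspect.signature`, and that main.py needs nothing else from the removed argparse helpers.

```python
import argparse
import inspect


def _to_bool(text):
    """Parse a command-line string into a bool."""
    return text.lower() in ("1", "true", "yes")


def add_init_args(parser, cls):
    """Add one --option per named parameter of cls.__init__ (hypothetical helper)."""
    group = parser.add_argument_group(f"{cls.__name__} arguments")
    for name, param in inspect.signature(cls.__init__).parameters.items():
        if name == "self" or param.kind in (param.VAR_POSITIONAL, param.VAR_KEYWORD):
            continue
        default = None if param.default is param.empty else param.default
        if isinstance(default, bool):
            arg_type = _to_bool
        elif isinstance(default, (int, float)):
            arg_type = type(default)
        else:
            # Anything without a simple default is passed through as a string.
            arg_type = str
        group.add_argument(f"--{name}", default=default, type=arg_type)
    return parser


def init_kwargs(args, cls):
    """Keep only the parsed values that cls.__init__ accepts."""
    names = set(inspect.signature(cls.__init__).parameters) - {"self"}
    return {key: value for key, value in vars(args).items() if key in names}


class FakeTrainer:
    """Stand-in for Trainer; the parameter names here are illustrative only."""

    def __init__(self, max_epochs=1, accelerator="auto", benchmark=False):
        self.max_epochs = max_epochs
        self.accelerator = accelerator
        self.benchmark = benchmark


parser = add_init_args(argparse.ArgumentParser(), FakeTrainer)
args = parser.parse_args(["--max_epochs", "5", "--benchmark", "true"])
trainer = FakeTrainer(**init_kwargs(args, FakeTrainer))
```

In main.py the call `add_init_args(parser, Trainer)` would take the place of `Trainer.add_argparse_args(parser)`. Any later code that relies on the companion helpers would need a similar replacement. The alternative is to install a 1.x release of pytorch_lightning, where add_argparse_args still exists, at the cost of changing the environment.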