apple / ml-mdm

Train high-quality text-to-image diffusion models in a data & compute efficient manner
https://machinelearning.apple.com/research/matryoshka-diffusion-models
MIT License

Use default_factory for mutable fields #23

Closed: luke-carlson closed this 2 weeks ago

luke-carlson commented 3 weeks ago
======================================== short test summary info =========================================
ERROR tests/test_configs.py - ValueError: mutable default <class 'ml_mdm.samplers.SamplerConfig'> for field sampler_config is not a...
ERROR tests/test_generate_batch.py - ValueError: mutable default <class 'ml_mdm.samplers.SamplerConfig'> for field sampler_config is not a...
ERROR tests/test_models.py - ValueError: mutable default <class 'ml_mdm.samplers.SamplerConfig'> for field sampler_config is not a...
ERROR tests/test_reader.py - ValueError: mutable default <class 'ml_mdm.samplers.SamplerConfig'> for field sampler_config is not a...
ERROR tests/test_train.py - ValueError: mutable default <class 'ml_mdm.samplers.SamplerConfig'> for field sampler_config is not a...
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! Interrupted: 5 errors during collection !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
=========================================== 5 errors in 20.49s ===========================================

I looked into it, and the error occurs because some of the dataclasses use mutable (unhashable) objects as field defaults, for example:

# Raises ValueError on Python 3.11+: a SamplerConfig() instance is unhashable
@dataclass
class DiffusionConfig:
    sampler_config: samplers.SamplerConfig = field(
        default=samplers.SamplerConfig(), metadata={"help": "Sampler configuration"}
    )
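To reproduce the failure without installing ml-mdm, here is a minimal sketch. The `SamplerConfig` below is a hypothetical stand-in with one made-up field (`num_steps`), not the real class. A regular `@dataclass` (default `eq=True`, `frozen=False`) sets `__hash__` to `None`, and since Python 3.11 the `dataclasses` module treats any default whose class is unhashable as mutable, so declaring the field raises the `ValueError` seen in the log above.

```python
import sys
from dataclasses import dataclass, field


# Hypothetical stand-in for ml_mdm.samplers.SamplerConfig; the real fields differ.
@dataclass
class SamplerConfig:
    num_steps: int = 50


# eq=True (the default) without frozen=True leaves instances unhashable.
assert SamplerConfig.__hash__ is None


def define_with_instance_default():
    """Declare the buggy config; return the ValueError message, or None if accepted."""
    try:
        @dataclass
        class DiffusionConfig:
            sampler_config: SamplerConfig = field(
                default=SamplerConfig(), metadata={"help": "Sampler configuration"}
            )
    except ValueError as err:
        return str(err)
    return None


message = define_with_instance_default()
if sys.version_info >= (3, 11):
    # Python 3.11+ rejects the field when the class is defined.
    print("rejected:", message)
else:
    # Older versions only reject list, dict and set defaults, so this is accepted
    # and every DiffusionConfig would silently share one SamplerConfig object.
    print("accepted, message is", message)
```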

Instead, they should use default_factory, like this:

from dataclasses import dataclass, field
from ml_mdm import samplers

@dataclass
class DiffusionConfig:
    sampler_config: samplers.SamplerConfig = field(
        default_factory=samplers.SamplerConfig, metadata={"help": "Sampler configuration"}
    )
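To see what `default_factory` changes, this self-contained sketch uses a hypothetical stand-in `SamplerConfig` (one made-up field, `num_steps`) in place of the real class. `default_factory` takes a zero-argument callable, here the class itself, and calls it once per `DiffusionConfig` instance, so mutating one config's sampler settings does not leak into another.

```python
from dataclasses import dataclass, field


# Hypothetical stand-in for ml_mdm.samplers.SamplerConfig; the real fields differ.
@dataclass
class SamplerConfig:
    num_steps: int = 50


@dataclass
class DiffusionConfig:
    # The class itself is the factory: a fresh SamplerConfig per instance.
    sampler_config: SamplerConfig = field(
        default_factory=SamplerConfig, metadata={"help": "Sampler configuration"}
    )


a = DiffusionConfig()
b = DiffusionConfig()
a.sampler_config.num_steps = 10  # mutate only a's sampler config

print(a.sampler_config.num_steps, b.sampler_config.num_steps)  # 10 50
print(a.sampler_config is b.sampler_config)  # False
```

Another option would be declaring the nested config with `@dataclass(frozen=True)`, which makes its instances hashable and therefore acceptable as a plain `default`, but that also forbids assigning to its fields after construction; `default_factory` leaves the existing classes unchanged.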
luke-carlson commented 2 weeks ago

Did you run pytest -m "not gpu" to make sure you skip tests that require a GPU?
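The `-m` option selects tests by marker expression, so `-m "not gpu"` deselects anything decorated with `@pytest.mark.gpu` (the later test run reports `1 deselected`). The sketch below is hypothetical: the test names and the extra `torch.cuda.is_available()` guard are assumptions, not ml-mdm's actual tests, and it assumes the `gpu` marker is registered (for example in the `markers` list under `[tool.pytest.ini_options]` in `pyproject.toml`) so pytest does not warn about an unknown mark.

```python
import pytest


# Hypothetical GPU-only test: `pytest -m "not gpu"` deselects it.
@pytest.mark.gpu
def test_tensor_sum_on_gpu():
    """Runs only when selected and a CUDA device is present."""
    torch = pytest.importorskip("torch")  # skip if torch is not installed
    if not torch.cuda.is_available():
        pytest.skip("CUDA device not available")
    x = torch.ones(2, device="cuda")
    assert x.sum().item() == 2.0


# Unmarked tests are still selected by `-m "not gpu"`.
def test_runs_on_cpu():
    """Placeholder CPU-only check."""
    assert sum([1, 1]) == 2
```

With `pytest -m "not gpu"`, only `test_runs_on_cpu` is collected and `test_tensor_sum_on_gpu` is counted as deselected.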

bdeanhardt commented 2 weeks ago

> Did you run pytest -m "not gpu" to make sure you skip tests that require a GPU?

Yes.

luke-carlson commented 2 weeks ago

👍 Yes, in that case let’s definitely fix any errors you now see!

ethanernst11 commented 2 weeks ago

Hi Luke, sorry for the confusion. We accidentally missed part of your PR when editing our own code, which is why we were still getting the errors. We've fixed it, and pytest now runs smoothly.

pedroborgescruz commented 2 weeks ago

Hi Luke! Aaliyah and I were able to get the tests passing on our end. This is what we get after running pytest -m "not gpu":

=============================== test session starts ================================
platform linux -- Python 3.12.3, pytest-8.3.3, pluggy-1.5.0
rootdir: /home/pcruz1/cs91/ml-mdm
configfile: pyproject.toml
plugins: cov-6.0.0, anyio-4.6.2.post1
collected 19 items / 1 deselected / 18 selected                                    

tests/test_configs.py .....                                                  [ 27%]
tests/test_generate_batch.py ..                                              [ 38%]
tests/test_generate_sample.py .                                              [ 44%]
tests/test_imports.py ....                                                   [ 66%]
tests/test_models.py ..                                                      [ 77%]
tests/test_reader.py ...                                                     [ 94%]
tests/test_train.py s                                                        [100%]

---------- coverage: platform linux, python 3.12.3-final-0 -----------
Name                                       Stmts   Miss  Cover
--------------------------------------------------------------
ml_mdm/__init__.py                             1      0   100%
ml_mdm/clis/__init__.py                        0      0   100%
ml_mdm/clis/download_tar_from_index.py       197    178    10%
ml_mdm/clis/generate_batch.py                139     90    35%
ml_mdm/clis/generate_sample.py               225    187    17%
ml_mdm/clis/run_torchmetrics.py              120    106    12%
ml_mdm/clis/scrape_cc12m.py                   62     62     0%
ml_mdm/clis/train_parallel.py                156    128    18%
ml_mdm/config.py                             127     13    90%
ml_mdm/diffusion.py                          191    122    36%
ml_mdm/distributed.py                         39     28    28%
ml_mdm/generate_html.py                       18     15    17%
ml_mdm/helpers.py                              8      4    50%
ml_mdm/language_models/__init__.py             0      0   100%
ml_mdm/language_models/factory.py             68     16    76%
ml_mdm/language_models/self_attention.py       4      4     0%
ml_mdm/language_models/tokenizer.py          118     73    38%
ml_mdm/language_models/transformer.py          4      4     0%
ml_mdm/lr_scaler.py                           18      9    50%
ml_mdm/models/__init__.py                      1      0   100%
ml_mdm/models/model_ema.py                    41     23    44%
ml_mdm/models/nested_unet.py                 114     51    55%
ml_mdm/models/unet.py                        506    277    45%
ml_mdm/reader.py                             125     36    71%
ml_mdm/s3_helpers.py                          56     43    23%
ml_mdm/samplers.py                           354    250    29%
ml_mdm/trainer.py                             52     48     8%
ml_mdm/utils/__init__.py                       0      0   100%
ml_mdm/utils/fix_old_checkpoints.py           10      7    30%
ml_mdm/utils/simple_logger.py                 85     54    36%
--------------------------------------------------------------
TOTAL                                       2839   1828    36%

=================== 17 passed, 1 skipped, 1 deselected in 32.75s ===================
luke-carlson commented 2 weeks ago

Great, now that you have the tests working, feel free to dive in and make any changes there, e.g. increasing the code coverage from 36%, adding docstrings to unit tests, improving test case logic, etc.
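As one possible direction for the docstring and test-logic improvements, here is a sketch of a regression test for this very bug. The config classes and the `unhashable_defaults` helper are hypothetical; a real version would import ml-mdm's config dataclasses instead. The helper applies the same rule Python 3.11+ uses: a plain default whose class has `__hash__` set to `None` counts as mutable.

```python
from dataclasses import MISSING, dataclass, field, fields, is_dataclass


# Hypothetical stand-ins; a real test would import ml-mdm's config classes.
@dataclass
class SamplerConfig:
    num_steps: int = 50


@dataclass
class DiffusionConfig:
    sampler_config: SamplerConfig = field(default_factory=SamplerConfig)
    num_training_steps: int = 1000


def unhashable_defaults(config_cls):
    """Return names of fields whose plain default is an unhashable object."""
    assert is_dataclass(config_cls)
    return [
        f.name
        for f in fields(config_cls)
        if f.default is not MISSING and type(f.default).__hash__ is None
    ]


def test_configs_use_default_factory_for_mutable_fields():
    """Config dataclasses must build mutable defaults via default_factory.

    Python 3.11+ already rejects such defaults when the class is defined; this
    check documents the rule and also catches it on older interpreters, which
    only reject list, dict and set defaults.
    """
    for config_cls in (SamplerConfig, DiffusionConfig):
        assert unhashable_defaults(config_cls) == [], config_cls.__name__
```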