MedicineToken / MedSegDiff

Medical Image Segmentation with Diffusion Model
MIT License

Error while loading saved checkpoint in V1 #193

Closed: pedropesserl closed this issue 1 week ago

pedropesserl commented 1 month ago

Command used: python scripts/segmentation_train.py --version 1 --data_name ISIC --data_dir $train_data_dir --out_dir $train_output_dir --image_size 256 --num_channels 128 --class_cond False --num_res_blocks 2 --num_heads 1 --learn_sigma True --use_scale_shift_norm False --attention_resolutions 16 --diffusion_steps 1000 --noise_schedule linear --rescale_learned_sigmas False --rescale_timesteps False --lr 1e-4 --batch_size 16 --lr_anneal_steps 100000 --resume_checkpoint ../train_out/savedmodel085000.pt

Error: AttributeError: 'UNetModel_v1preview' object has no attribute 'load_part_state_dict'. Did you mean: '_load_from_state_dict'?

Full output:

Logging to ../train_out
creating data loader...
creating model and diffusion...
training...
resume model
loading model from checkpoint: ../train_out/savedmodel085000.pt...
[rank0]: Traceback (most recent call last):
[rank0]:   File "/home/pesserl/medsegdiff/MedSegDiff/scripts/segmentation_train.py", line 117, in <module>
[rank0]:     main()
[rank0]:   File "/home/pesserl/medsegdiff/MedSegDiff/scripts/segmentation_train.py", line 69, in main
[rank0]:     TrainLoop(
[rank0]:   File "/home/pesserl/medsegdiff/MedSegDiff/./guided_diffusion/train_util.py", line 83, in __init__
[rank0]:     self._load_and_sync_parameters()
[rank0]:   File "/home/pesserl/medsegdiff/MedSegDiff/./guided_diffusion/train_util.py", line 133, in _load_and_sync_parameters
[rank0]:     self.model.load_part_state_dict(
[rank0]:     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/pesserl/.local/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1709, in __getattr__
[rank0]:     raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
[rank0]: AttributeError: 'UNetModel_v1preview' object has no attribute 'load_part_state_dict'. Did you mean: '_load_from_state_dict'?

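Suggested fix: copy the load_part_state_dict() method from UNetModel_newpreview into UNetModel_v1preview. The resume path in guided_diffusion/train_util.py (_load_and_sync_parameters) calls this method on the model, but UNetModel_v1preview does not define it.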
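
For anyone hitting the same error, here is a rough sketch of what a partial state-dict loader of this kind could look like. It is not copied from the repository: the body of the real UNetModel_newpreview.load_part_state_dict may differ, and PartialLoadMixin and TinyModel are hypothetical names used only to make the example runnable. The assumption is that the method copies every checkpoint tensor whose name and shape match the model, and skips the rest.

```python
import torch
import torch.nn as nn


class PartialLoadMixin:
    """Hypothetical mixin (not part of MedSegDiff) providing a partial loader."""

    def load_part_state_dict(self, state_dict):
        # Assumption: only entries whose name and shape match are copied;
        # everything else in the checkpoint is skipped and reported.
        own_state = self.state_dict()
        skipped = []
        with torch.no_grad():
            for name, param in state_dict.items():
                if name not in own_state or own_state[name].shape != param.shape:
                    skipped.append(name)
                    continue
                # state_dict() tensors share storage with the live parameters,
                # so an in-place copy updates the model itself.
                own_state[name].copy_(param)
        return skipped


class TinyModel(PartialLoadMixin, nn.Module):
    """Stand-in for UNetModel_v1preview, used only for this illustration."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 2, kernel_size=3)
        self.head = nn.Linear(2, 1)


if __name__ == "__main__":
    src, dst = TinyModel(), TinyModel()
    checkpoint = {k: v.clone() for k, v in src.state_dict().items()}
    checkpoint["extra.weight"] = torch.zeros(3)  # key the model does not have
    print("skipped keys:", dst.load_part_state_dict(checkpoint))
```

In the actual repository the method would go directly on UNetModel_v1preview rather than in a mixin. If the V1 checkpoint matches the model exactly, calling the standard load_state_dict() at the call site in train_util.py may also be enough, but I have not verified that against this codebase.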