KidsWithTokens / MedSegDiff

Medical Image Segmentation with Diffusion Model
MIT License

Unable to execute segmentation_sample.py command #114

Closed: vis6060 closed this issue 1 year ago

vis6060 commented 1 year ago

I ran the segmentation_train.py command as described in the README for 100,000 steps. Now, when I run the segmentation_sample.py command as described in the README, I get the following error. Please help.

raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for UNetModel_newpreview:
    size mismatch for input_blocks.0.0.weight: copying a param with shape torch.Size([128, 2, 3, 3]) from checkpoint, the shape in current model is torch.Size([128, 1, 3, 3]).
    size mismatch for hwm.conv_blocks_context.0.blocks.0.conv.weight: copying a param with shape torch.Size([32, 1, 3, 3]) from checkpoint, the shape in current model is torch.Size([32, 0, 3, 3]).
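For readers hitting the same error: a Conv2d weight is stored as (out_channels, in_channels, kH, kW), so the second number in each shape is the input-channel count. The checkpoint expects 2 channels in the first conv, while the sampling model was built for 1. A minimal sketch that reproduces this kind of mismatch, using a hypothetical `ToyUNetStem` module (not MedSegDiff's actual `UNetModel_newpreview`):

```python
import torch
from torch import nn


class ToyUNetStem(nn.Module):
    """Hypothetical stand-in for a UNet's first conv layer (not MedSegDiff code)."""

    def __init__(self, in_ch: int):
        super().__init__()
        # Conv2d weight shape is (out_channels, in_channels, kH, kW).
        self.conv = nn.Conv2d(in_ch, 128, kernel_size=3, padding=1)


# Weights "trained" with in_ch=2, as in the checkpoint from the error above.
checkpoint = ToyUNetStem(in_ch=2).state_dict()

# Rebuilding the model for sampling with in_ch=1 triggers the size mismatch.
model = ToyUNetStem(in_ch=1)
try:
    model.load_state_dict(checkpoint)
    mismatch_error = None
except RuntimeError as err:
    mismatch_error = str(err)

print(mismatch_error)

# With a matching in_ch, the same checkpoint loads cleanly.
ToyUNetStem(in_ch=2).load_state_dict(checkpoint)
```

Only `conv.weight` is reported, because the bias shape (128,) does not depend on the input-channel count.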

earsonlau commented 1 year ago

use saved100000.pt

vis6060 commented 1 year ago

Thank you, earsonlau, for your comment. I was able to resolve the error. I am using the "emasavedmodel_0.9999_100000.pt" file. My mistake was this: for the LiTS dataset of CT images, I have only one image file per sample (a single CT channel), so in the segmentation_sample.py file I should have set args.in_ch = 2 (one image channel plus one mask channel). I was initially setting it to 1.
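Instead of guessing `in_ch`, you can read it back from the checkpoint. This is a sketch under two assumptions: the `.pt` file holds a plain state_dict whose keys match the names in the error message, and `in_ch` counts image channels plus one mask channel. The shapes above are consistent with the second assumption: 2 vs 1 in `input_blocks.0.0.weight`, and 1 vs 0 in the `hwm` conv, which appears to receive `in_ch - 1` channels. The helper name `infer_in_ch` is hypothetical.

```python
def infer_in_ch(state_dict, key="input_blocks.0.0.weight"):
    """Return the input-channel count a checkpoint was trained with.

    Hypothetical helper: assumes `state_dict` is a plain mapping of
    parameter names to tensors. Conv2d weights have shape
    (out_channels, in_channels, kH, kW), so dimension 1 is in_ch.
    """
    if key not in state_dict:
        raise KeyError(f"{key!r} not found; is this a plain state_dict?")
    return int(state_dict[key].shape[1])
```

Usage: `infer_in_ch(torch.load("emasavedmodel_0.9999_100000.pt", map_location="cpu"))` should return 2 for the checkpoint in this issue. Pass that value as `in_ch` when building the model in segmentation_sample.py.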

WJunde, feel free to close this issue.