mobaidoctor / med-ddpm

GNU General Public License v3.0

Error with inference from newly trained model #16

Closed LRpz closed 7 months ago

LRpz commented 7 months ago

Hi,

Thank you very much for your codebase, it is very clean! I successfully trained your model on the whole_head dataset that you provided, using your 'train.py' script.

Running inference ('sample.py') with your pretrained model 'model_128.pt' works, but I get an error when I try to load a model produced by my own training run.

model = create_model(input_size, num_channels, num_res_blocks, in_channels=in_channels, out_channels=out_channels).cuda()
diffusion = GaussianDiffusion(
    model,
    image_size = input_size,
    depth_size = depth_size,
    timesteps = 250,   # number of steps
    loss_type = 'L1', 
    with_condition=True,
).cuda()

diffusion.load_state_dict(torch.load(weightfile)['ema'])
print("Model Loaded!")

returns:


RuntimeError                              Traceback (most recent call last)
c:\Users\rappez\Documents\git_codebase\med-ddpm\sample.py in line 11
     77 model = create_model(input_size, num_channels, num_res_blocks, in_channels=in_channels, out_channels=out_channels).cuda()
     78 diffusion = GaussianDiffusion(
     79     model,
     80     image_size = input_size,
   (...)
     84     with_condition=True,
     85 ).cuda()
---> 87 diffusion.load_state_dict(torch.load(weightfile)['ema'])
     88 print("Model Loaded!")

File c:\Users\rappez\Anaconda3\envs\torch_apex\lib\site-packages\torch\nn\modules\module.py:2153, in Module.load_state_dict(self, state_dict, strict, assign)
   2148         error_msgs.insert(
   2149             0, 'Missing key(s) in state_dict: {}. '.format(
   2150                 ', '.join(f'"{k}"' for k in missing_keys)))
   2152 if len(error_msgs) > 0:
-> 2153     raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
   2154                        self.__class__.__name__, "\n\t".join(error_msgs)))
   2155 return _IncompatibleKeys(missing_keys, unexpected_keys)

RuntimeError: Error(s) in loading state_dict for GaussianDiffusion:
    size mismatch for denoise_fn.input_blocks.0.0.weight: copying a param with shape torch.Size([64, 1, 3, 3, 3]) from checkpoint, the shape in current model is torch.Size([64, 3, 3, 3, 3]).

mobaidoctor commented 7 months ago

Hi, the error message indicates a shape mismatch between the parameters stored in your checkpoint and the model built in sample.py. The weight tensor denoise_fn.input_blocks.0.0.weight in the checkpoint you are trying to load has shape torch.Size([64, 1, 3, 3, 3]), whereas the model defined in your sample.py expects torch.Size([64, 3, 3, 3, 3]). The second dimension of a convolution weight is its number of input channels, so the checkpoint was trained with 1 input channel while sample.py builds the model with 3. Please update your inference script 'sample.py' to use the same input channel configuration as your training script.
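
For anyone hitting the same error, here is a minimal sketch of how the mismatch could be located before calling load_state_dict. It is not part of the med-ddpm codebase: the helper names find_shape_mismatches and conv_in_channels are hypothetical, and the shapes are copied from the error message in this issue. It assumes ordinary (non-grouped) convolutions, whose weights are laid out as (out_channels, in_channels, *kernel_size).

```python
def find_shape_mismatches(checkpoint_shapes, model_shapes):
    """Return {key: (checkpoint_shape, model_shape)} for keys whose shapes differ.

    Both arguments map parameter names to shape tuples. Keys present in only
    one of the two dicts are ignored here.
    """
    mismatches = {}
    for key, ckpt_shape in checkpoint_shapes.items():
        model_shape = model_shapes.get(key)
        if model_shape is not None and tuple(ckpt_shape) != tuple(model_shape):
            mismatches[key] = (tuple(ckpt_shape), tuple(model_shape))
    return mismatches


def conv_in_channels(weight_shape):
    """Read in_channels from a conv weight shape.

    Assumes groups=1, i.e. layout (out_channels, in_channels, *kernel_size).
    """
    return weight_shape[1]


# Shapes copied from the error message in this issue.
checkpoint_shapes = {"denoise_fn.input_blocks.0.0.weight": (64, 1, 3, 3, 3)}
model_shapes = {"denoise_fn.input_blocks.0.0.weight": (64, 3, 3, 3, 3)}

mismatches = find_shape_mismatches(checkpoint_shapes, model_shapes)
for key, (ckpt_shape, model_shape) in mismatches.items():
    print(
        f"{key}: checkpoint in_channels={conv_in_channels(ckpt_shape)}, "
        f"model in_channels={conv_in_channels(model_shape)}"
    )

# With PyTorch available, the two dicts could be built from the real objects
# used in sample.py (an assumption about how you would wire this in), e.g.:
#   state = torch.load(weightfile, map_location="cpu")["ema"]
#   checkpoint_shapes = {k: tuple(v.shape) for k, v in state.items()}
#   model_shapes = {k: tuple(v.shape) for k, v in diffusion.state_dict().items()}
```

If the reported in_channels values differ, as they do here (1 in the checkpoint, 3 in the model), the value passed to create_model in sample.py most likely needs to match the one used in train.py.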