MedMNIST / experiments

Codebase for reproducible benchmarking experiments in MedMNIST v2
Apache License 2.0

MedMNIST3D load model and inference #2

Closed: javalisiddu closed this issue 2 years ago.

javalisiddu commented 2 years ago

I trained a model using the MedMNIST3D code. However, I am not able to load it and run inference on a single image. Below is the error I get:

RuntimeError                              Traceback (most recent call last)
C:\Users\SMALLA~1.JAV\AppData\Local\Temp/ipykernel_17416/966713097.py in <module>
      1 model = ResNet18(3,11)
      2 model3d = model_to_syncbn(Conv3dConverter(model, i3d_repeat_axis=-3))
----> 3 model3d.load_state_dict(torch.load('best_model.pth', map_location='cpu')['net'])

~\.conda\envs\segment\lib\site-packages\acsconv\converters\conv3d_converter.py in load_state_dict(self, state_dict, strict, i3d_repeat_axis)
     44             return load_state_dict_from_2d_to_i3d(self.model, state_dict, strict, repeat_axis=i3d_repeat_axis)
     45         else:
---> 46             return self.model.load_state_dict(state_dict, strict)
     47 
     48 

~\.conda\envs\segment\lib\site-packages\torch\nn\modules\module.py in load_state_dict(self, state_dict, strict)
   1481         if len(error_msgs) > 0:
   1482             raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
-> 1483                                self.__class__.__name__, "\n\t".join(error_msgs)))
   1484         return _IncompatibleKeys(missing_keys, unexpected_keys)
   1485 

RuntimeError: Error(s) in loading state_dict for ResNet:
	size mismatch for conv1.weight: copying a param with shape torch.Size([64, 1, 3, 3, 3]) from checkpoint, the shape in current model is torch.Size([64, 3, 3, 3, 3]).
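To pin down which parameters disagree before calling `load_state_dict`, one can compare the two state dicts shape by shape. The sketch below is illustrative only: `find_shape_mismatches` is a hypothetical helper, not part of MedMNIST, ACSConv or PyTorch, and NumPy arrays stand in for tensors. For an ungrouped `Conv3d`, PyTorch stores the weight as (out_channels, in_channels, kD, kH, kW), so the mismatch reported above sits in the input-channel dimension: 1 in the checkpoint versus 3 in the model.

```python
import numpy as np


def find_shape_mismatches(checkpoint_sd, model_sd):
    """Return (key, checkpoint_shape, model_shape) for every parameter that
    exists in both state dicts but with a different shape.

    Hypothetical helper: works on any mapping from names to objects with a
    `.shape` attribute, e.g. NumPy arrays or torch tensors.
    """
    mismatches = []
    for key, value in checkpoint_sd.items():
        # Keys present only in the checkpoint are ignored here.
        if key in model_sd and tuple(value.shape) != tuple(model_sd[key].shape):
            mismatches.append((key, tuple(value.shape), tuple(model_sd[key].shape)))
    return mismatches


# Shapes copied from the error message above.
ckpt = {"conv1.weight": np.zeros((64, 1, 3, 3, 3))}
model = {"conv1.weight": np.zeros((64, 3, 3, 3, 3))}
print(find_shape_mismatches(ckpt, model))
# [('conv1.weight', (64, 1, 3, 3, 3), (64, 3, 3, 3, 3))]
```

With PyTorch, one would presumably pass `torch.load('best_model.pth', map_location='cpu')['net']` and `model3d.model.state_dict()`, since the converter keeps the wrapped network in `.model`, as the traceback shows.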


duducheng commented 2 years ago

Hi,

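This doesn't look like a bug in our code: it seems that your checkpoint and the model/input you use for inference do not match the dataloader settings in the training script. The checkpoint's conv1.weight has shape [64, 1, 3, 3, 3], i.e. its first layer was trained on 1-channel input, whereas ResNet18(3, 11) builds a first layer for 3-channel input.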

The model takes either 3-channel or 1-channel input, controlled by the as_rgb option. For grey-scale images with as_rgb enabled, we simply copy the single channel into 3.
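The channel copy can be sketched with NumPy for the single-volume inference case from the question. This is an illustration under assumptions, not the repository's transform: `volume_to_batch` is a hypothetical helper, the input is assumed to be one grey-scale (D, H, W) volume such as a 28x28x28 MedMNIST3D sample, and any intensity scaling used during training is deliberately left out.

```python
import numpy as np


def volume_to_batch(volume: np.ndarray, as_rgb: bool) -> np.ndarray:
    """Turn one grey-scale 3D volume of shape (D, H, W) into a batch of
    shape (1, C, D, H, W), where C is 3 if as_rgb else 1.

    Hypothetical helper: it only mirrors the channel-copy idea; apply the
    same intensity scaling your training transform used (not shown here).
    """
    if volume.ndim != 3:
        raise ValueError(f"expected a (D, H, W) volume, got shape {volume.shape}")
    channels = 3 if as_rgb else 1
    # Copy the single grey channel `channels` times along a new leading axis.
    stacked = np.stack([volume] * channels, axis=0).astype(np.float32)
    # Add the batch dimension expected by a Conv3d-based model.
    return stacked[np.newaxis, ...]


vol = np.random.rand(28, 28, 28)
print(volume_to_batch(vol, as_rgb=True).shape)   # (1, 3, 28, 28, 28)
print(volume_to_batch(vol, as_rgb=False).shape)  # (1, 1, 28, 28, 28)
```

The resulting array could then be wrapped with `torch.from_numpy` and passed to the model in `eval()` mode under `torch.no_grad()`. Whatever the exact pipeline, `as_rgb` at inference time has to match the setting the checkpoint was trained with.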

Please check which setting your checkpoint was trained with. You can also refer to the dataloader and data transforms in the script for details.
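One way to check the checkpoint is to read the input-channel count off its first convolution. A minimal sketch, assuming the first layer is stored as `conv1.weight` (as in the error message above) with shape (out_channels, in_channels, kD, kH, kW); `infer_as_rgb` is a hypothetical helper, not a MedMNIST function, and NumPy arrays stand in for tensors.

```python
import numpy as np


def infer_as_rgb(state_dict, first_conv_key="conv1.weight"):
    """Guess the as_rgb setting a checkpoint was trained with.

    Hypothetical helper. Assumes the first convolution's weight is stored
    under `first_conv_key` with shape (out, in_channels, kD, kH, kW).
    """
    in_channels = state_dict[first_conv_key].shape[1]
    if in_channels not in (1, 3):
        raise ValueError(f"unexpected number of input channels: {in_channels}")
    # 3 input channels means the grey channel was copied (as_rgb enabled).
    return in_channels == 3


# Shape taken from the checkpoint in this issue.
print(infer_as_rgb({"conv1.weight": np.zeros((64, 1, 3, 3, 3))}))  # False
```

For the checkpoint in this issue it returns False, which would suggest building the model as ResNet18(1, 11) (assuming, as the call in the traceback suggests, that the first argument is the number of input channels) and preparing inputs with as_rgb disabled.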

Jiancheng