Closed: javalisiddu closed this issue 2 years ago.
Hi,
This doesn't look like a bug in our code: it seems that your input or checkpoint does not match the dataloader settings used in the script.
The model takes either 3-channel or 1-channel input, controlled by `as_rgb`. For grayscale data, we simply copy the single channel into 3 channels.
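For illustration, the grayscale-to-3-channel copy can be sketched with NumPy. This assumes a single volume stored as a (D, H, W) uint8 array with values from 0 to 255 (MedMNIST3D volumes are 28×28×28). `to_model_input` is a hypothetical helper, not the actual MedMNIST dataloader or transform.

```python
import numpy as np


def to_model_input(volume: np.ndarray, as_rgb: bool) -> np.ndarray:
    """Turn a grayscale (D, H, W) uint8 volume into a channel-first
    float32 array of shape (C, D, H, W), where C is 3 if as_rgb else 1."""
    if volume.ndim != 3:
        raise ValueError(f"expected a (D, H, W) volume, got shape {volume.shape}")
    # Scale intensities from [0, 255] to [0, 1].
    scaled = volume.astype(np.float32) / 255.0
    n_channels = 3 if as_rgb else 1
    # Grayscale -> "RGB": stack identical copies of the single channel.
    return np.stack([scaled] * n_channels, axis=0)


volume = np.random.randint(0, 256, size=(28, 28, 28), dtype=np.uint8)
print(to_model_input(volume, as_rgb=True).shape)   # (3, 28, 28, 28)
print(to_model_input(volume, as_rgb=False).shape)  # (1, 28, 28, 28)
```

For single-image inference the model still expects a batch dimension, e.g. `torch.from_numpy(x).unsqueeze(0)` to get shape (1, C, D, H, W).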
You should check your checkpoint: it must have been trained with the same number of input channels as the model you construct. You can also refer to the dataloader and data transform for details.
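One way to check the checkpoint is to read the shape of its first convolution weight. This sketch assumes the state dict is stored under `'net'` (as in the snippet in this thread) and that the first layer is named `conv1`; `checkpoint_in_channels` is a hypothetical helper. Conv weights are laid out as (out_channels, in_channels, *kernel_size), so with groups=1, dimension 1 gives the input channel count. A NumPy array stands in for the tensor so the example runs without PyTorch.

```python
import numpy as np


def checkpoint_in_channels(state_dict, key: str = "conv1.weight") -> int:
    """Return the number of input channels the first conv layer was trained with.

    Conv weights are laid out as (out_channels, in_channels, *kernel_size),
    assuming groups=1, so dimension 1 is the input channel count.
    """
    if key not in state_dict:
        raise KeyError(f"{key!r} not found; first keys are {list(state_dict)[:5]}")
    return int(state_dict[key].shape[1])


# Stand-in for torch.load('best_model.pth', map_location='cpu')['net'],
# using the checkpoint shape reported in the traceback in this thread.
fake_state_dict = {"conv1.weight": np.zeros((64, 1, 3, 3, 3), dtype=np.float32)}
print(checkpoint_in_channels(fake_state_dict))  # 1
```

With PyTorch, the same call should work on `torch.load('best_model.pth', map_location='cpu')['net']`, since tensors also expose `.shape`.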
Jiancheng
I have trained a MedMNIST3D model. However, I am not able to load it and run inference on a single image. This is the error I get:
RuntimeError Traceback (most recent call last)
C:\Users\SMALLA~1.JAV\AppData\Local\Temp/ipykernel_17416/966713097.py in
1 model = ResNet18(3,11)
2 model3d = model_to_syncbn(Conv3dConverter(model, i3d_repeat_axis=-3))
----> 3 model3d.load_state_dict(torch.load('best_model.pth', map_location='cpu')['net'])
~\.conda\envs\segment\lib\site-packages\acsconv\converters\conv3d_converter.py in load_state_dict(self, state_dict, strict, i3d_repeat_axis)
44 return load_state_dict_from_2d_to_i3d(self.model, state_dict, strict, repeat_axis=i3d_repeat_axis)
45 else:
---> 46 return self.model.load_state_dict(state_dict, strict)

~\.conda\envs\segment\lib\site-packages\torch\nn\modules\module.py in load_state_dict(self, state_dict, strict)
1481 if len(error_msgs) > 0:
1482 raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
-> 1483 self.__class__.__name__, "\n\t".join(error_msgs)))
1484 return _IncompatibleKeys(missing_keys, unexpected_keys)
RuntimeError: Error(s) in loading state_dict for ResNet:
size mismatch for conv1.weight: copying a param with shape torch.Size([64, 1, 3, 3, 3]) from checkpoint, the shape in current model is torch.Size([64, 3, 3, 3, 3]).
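The error can be read mechanically by comparing parameter shapes between the checkpoint and the freshly built model. A minimal sketch, with `shape_mismatches` as a hypothetical helper and NumPy arrays standing in for the tensors (in practice you would pass `torch.load(...)['net']` and `model.state_dict()`):

```python
import numpy as np


def shape_mismatches(checkpoint_sd, model_sd):
    """List (key, checkpoint_shape, model_shape) for every parameter present
    in both state dicts whose shapes differ."""
    mismatches = []
    for key, ckpt_param in checkpoint_sd.items():
        if key in model_sd:
            ckpt_shape = tuple(ckpt_param.shape)
            model_shape = tuple(model_sd[key].shape)
            if ckpt_shape != model_shape:
                mismatches.append((key, ckpt_shape, model_shape))
    return mismatches


# Shapes taken from the error above: checkpoint trained on 1-channel input,
# model built as ResNet18(3, 11).
ckpt_sd = {"conv1.weight": np.zeros((64, 1, 3, 3, 3))}
model_sd = {"conv1.weight": np.zeros((64, 3, 3, 3, 3))}
for key, got, expected in shape_mismatches(ckpt_sd, model_sd):
    # For conv weights, dimension 1 is in_channels (groups=1 assumed).
    print(f"{key}: checkpoint in_channels={got[1]}, model in_channels={expected[1]}")
# prints: conv1.weight: checkpoint in_channels=1, model in_channels=3
```

Here the checkpoint was trained on 1-channel input while the model was built as `ResNet18(3, 11)`, so presumably the fix is to build the model with matching input channels (e.g. `ResNet18(1, 11)`) before wrapping it with `Conv3dConverter`, and to feed 1-channel volumes at inference; alternatively, retrain with `as_rgb` enabled so that the checkpoint expects 3 channels.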