WaterKnight1998 / SemTorch

Apache License 2.0

Can't load HRNet Segmentation weights from PTH file #7

Open Andrevmatias opened 3 years ago

Andrevmatias commented 3 years ago

I created the learner with get_segmentation_learner(architecture_name='hrnet', backbone_name='hrnet_w18').

I used the following callback to save the model during training: SaveModelCallback(monitor='dice_multi', fname='best_model', with_opt=True). After loading "best_model.pth" with learner.load, the predictions are zero-filled masks.

Predictions made with the learner right after training are correct.

WaterKnight1998 commented 2 years ago

Hi, sorry for the super late reply. Were you using Windows?

Andrevmatias commented 2 years ago

No, we were using Google Colab notebooks (Ubuntu 18.04.3 LTS 64-bit).

bossyang commented 1 year ago

The w32 and w48 backbones don't work either.

PyTorch 1.12, Python 3.9 on Paperspace

learn = get_segmentation_learner(dls=dls, number_classes=2, segmentation_type="Semantic Segmentation",
                                 architecture_name="hrnet", backbone_name="hrnet_w32",
                                 splitter=segmentron_splitter,
                                 loss_func=CustomLoss(),
                                 metrics=[Dice, foreground_acc, JaccardCoeff],
                                 wd=1e-3).to_fp16()
RuntimeError                              Traceback (most recent call last)
File /usr/local/lib/python3.9/dist-packages/semtorch/models/archs/backbones/build.py:51, in load_backbone_pretrained(model, backbone)
     49     weights_path = download(model_urls[backbone], path=weights_path)
---> 51     msg = model.init_weights(pretrained=weights_path)
     52 else:    

File /usr/local/lib/python3.9/dist-packages/semtorch/models/archs/backbones/hrnet.py:475, in HighResolutionNet.init_weights(self, pretrained)
    474     model_dict.update(pretrained_dict)
--> 475     self.load_state_dict(model_dict)
    476 return "HRNet backbone wieghts loaded"

File /usr/local/lib/python3.9/dist-packages/torch/nn/modules/module.py:1604, in Module.load_state_dict(self, state_dict, strict)
   1603 if len(error_msgs) > 0:
-> 1604     raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
   1605                        self.__class__.__name__, "\n\t".join(error_msgs)))
   1606 return _IncompatibleKeys(missing_keys, unexpected_keys)
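
The traceback above stops before the per-key messages. load_state_dict usually reports missing keys, unexpected keys, and size mismatches, so one way to narrow this down is to compare the model's state dict with the downloaded one before loading. The helper below is only a sketch: diff_state_dicts is a hypothetical name and not part of SemTorch or PyTorch, and it only assumes that each value exposes a .shape attribute, as tensors do.

```python
def diff_state_dicts(model_state, checkpoint_state):
    """Compare two name -> tensor mappings by key and by shape.

    Hypothetical helper for debugging; not part of SemTorch or PyTorch.
    """
    model_keys = set(model_state)
    ckpt_keys = set(checkpoint_state)

    def shape_of(value):
        # Tensors and arrays expose .shape; anything else is treated as shapeless.
        shape = getattr(value, "shape", None)
        return None if shape is None else tuple(shape)

    # Keys present on both sides whose shapes disagree.
    mismatched = {
        key: (shape_of(model_state[key]), shape_of(checkpoint_state[key]))
        for key in sorted(model_keys & ckpt_keys)
        if shape_of(model_state[key]) != shape_of(checkpoint_state[key])
    }
    return {
        "missing": sorted(model_keys - ckpt_keys),      # in the model, not in the file
        "unexpected": sorted(ckpt_keys - model_keys),   # in the file, not in the model
        "shape_mismatch": mismatched,
    }
```

It could be called as diff_state_dicts(model.state_dict(), torch.load(weights_path, map_location="cpu")), assuming the file holds a plain state dict rather than a wrapper dictionary. Many shape mismatches on shared keys would suggest the file on disk belongs to a different backbone width.
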
bossyang commented 1 year ago

Complete error message: hrnet-w32-error.txt

bossyang commented 1 year ago

After deleting the previously cached file, the notebook can load the hrnet_w32 weights. It seems the PTH files are always cached in ~/.cache/torch/checkpoints.
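
For anyone hitting the same error, the workaround above can be scripted. This is a minimal sketch using only the standard library: clear_cached_weights is a hypothetical helper, the directory is the one reported in this thread (other setups may cache elsewhere), and the glob pattern is an assumption about how the cached file is named, so list the directory first.

```python
from pathlib import Path

# Cache location reported in this thread; other environments may differ.
DEFAULT_CACHE_DIR = Path.home() / ".cache" / "torch" / "checkpoints"


def clear_cached_weights(pattern, cache_dir=DEFAULT_CACHE_DIR, dry_run=True):
    """Return the names of cached files matching `pattern`.

    Files are deleted only when dry_run is False.
    Hypothetical helper; not part of SemTorch or PyTorch.
    """
    cache_dir = Path(cache_dir)
    if not cache_dir.is_dir():
        return []
    matches = sorted(p for p in cache_dir.glob(pattern) if p.is_file())
    if not dry_run:
        for path in matches:
            path.unlink()
    return [p.name for p in matches]
```

Calling clear_cached_weights("*hrnet*") only lists the matching files. Passing dry_run=False deletes them, after which the weights should be downloaded again the next time the learner is created.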