ZhengPeng7 / BiRefNet

[CAAI AIR'24] Bilateral Reference for High-Resolution Dichotomous Image Segmentation
https://www.birefnet.top
MIT License

model load mismatch #28

Closed KennyAtDM closed 3 months ago

KennyAtDM commented 4 months ago

In inference.py, how do I load models other than BiRefNet-massive-epoch_240.pth? I keep getting "RuntimeError: Error(s) in loading state_dict for BiRefNet", but I don't know where to set the model weights correctly.

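```python
for weights in weights_lst:
    # Epoch number from a name like 'BiRefNet-massive-epoch_240.pth'.
    # Slice off the suffix: weights.strip('.pth') strips characters, not a suffix.
    epoch = weights[:-len('.pth')].split('epoch_')[-1]
    print(epoch)
    # 'int(epoch) % 1' is always 0, so this filter keeps every checkpoint.
    if int(epoch) % 1 != 0:
        continue
    print('\tInferencing {}...'.format(weights))
    state_dict = torch.load(weights, map_location='cpu')
    state_dict = check_state_dict(state_dict)
    model.load_state_dict(state_dict)
    model = model.to(device)
```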
ZhengPeng7 commented 4 months ago

Do you mean the script cannot locate the target weights file? If so, remove the `if` lines and list the target weights files explicitly in weights_lst.
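A minimal sketch of the suggestion above: list the checkpoints explicitly and, if you still want an epoch filter, parse the epoch number robustly. The helper names `epoch_of` and `select_weights` are hypothetical, not part of inference.py; the file names follow the `...epoch_N.pth` pattern used in this thread.

```python
import re
from pathlib import Path
from typing import List, Optional


def epoch_of(weights_path: str) -> Optional[int]:
    """Extract N from '...epoch_N.pth', or None if the name has no epoch tag."""
    # Path(...).stem drops only the final '.pth' suffix, unlike str.strip('.pth'),
    # which removes any of the characters '.', 'p', 't', 'h' from both ends.
    match = re.search(r"epoch_(\d+)$", Path(weights_path).stem)
    return int(match.group(1)) if match else None


def select_weights(weights_lst: List[str], every: int = 1) -> List[str]:
    """Keep checkpoints whose epoch is a multiple of `every` (every=1 keeps all tagged ones)."""
    selected = []
    for weights in weights_lst:
        epoch = epoch_of(weights)
        if epoch is not None and epoch % every == 0:
            selected.append(weights)
    return selected


if __name__ == "__main__":
    # Explicit list instead of globbing a directory; paths are placeholders.
    weights_lst = ["BiRefNet-massive-bb_swin_v1_tiny-epoch_235.pth"]
    for weights in select_weights(weights_lst):
        print("Inferencing", weights, "at epoch", epoch_of(weights))
```

With `every=1` this behaves like the original `% 1` test (nothing is skipped); a larger value such as 5 or 10 would test only every N-th checkpoint.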

ZhengPeng7 commented 4 months ago

If the error is in loading the weights (key or weights mismatch), please tell me.

KennyAtDM commented 4 months ago

> If the error is in loading the weights (key or weights mismatch), please tell me.

Yes, the error occurs when loading the weights, after I remove the `if` lines.

KennyAtDM commented 4 months ago

I am trying to load BiRefNet-massive-bb_swin_v1_tiny-epoch_235.pth rather than BiRefNet-massive-epoch_240.pth

ZhengPeng7 commented 4 months ago

Hi, @KennyAtDM, I checked the weight loading as shown below, and it seems alright:

[Screenshot 2024-06-07 09:04:05]

I guess the key is this: if you want to load weights trained with a certain backbone, you first need to set `bb` in config.py to the corresponding backbone, so that `BiRefNet()` builds the matching architecture. Give it a try.
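A minimal sketch of that check, so a mismatch fails early with a readable message instead of a long state_dict error. Assumptions: the file-name tag `bb_swin_v1_tiny` corresponds to `'swin_v1_t'` in config.py (as resolved later in this thread); untagged names such as BiRefNet-massive-epoch_240.pth use the default `'swin_v1_l'` (index [6]); the mapping and both helper names are hypothetical, not part of the repo.

```python
import re
from typing import Optional

# Hypothetical mapping from the backbone tag in checkpoint file names to the names
# in config.py's self.bb list. Only the tiny Swin pair comes from this thread.
FILENAME_TAG_TO_BB = {"swin_v1_tiny": "swin_v1_t"}
# Assumption: checkpoints without a 'bb_' tag were trained with the default, index [6].
DEFAULT_BB = "swin_v1_l"


def backbone_for_checkpoint(filename: str) -> Optional[str]:
    """Return the config.py backbone a checkpoint expects, or None if the tag is unknown."""
    match = re.search(r"bb_(.+?)-epoch_", filename)
    if match is None:
        return DEFAULT_BB
    return FILENAME_TAG_TO_BB.get(match.group(1))


def check_backbone(filename: str, configured_bb: str) -> None:
    """Raise before building the model if config.py selects a different backbone."""
    expected = backbone_for_checkpoint(filename)
    if expected is not None and expected != configured_bb:
        raise ValueError(
            f"{filename} expects bb={expected!r}, but config.py selects {configured_bb!r}; "
            "update self.bb and restart the session before calling BiRefNet()."
        )
```

It would be called right before building the model, e.g. `check_backbone(os.path.basename(weights_path), Config().bb)`, assuming `Config` can be imported from config.py. Unknown tags return None and are not checked, so the mapping should be extended for other backbones.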

Tell me if there is still any problem :)

ZhengPeng7 commented 3 months ago

Re-open it if you still have this problem.

karndeb commented 3 months ago

@ZhengPeng7 I am still having the same issue. I changed the self.bb value in the config, but it is still not working. Can you help? I am using the single-image inference notebook. Here are the code block and the error:

```python
model = BiRefNet(bb_pretrained=False)
state_dict = torch.load("/content/drive/MyDrive/birefnet-tiny/BiRefNet-massive-bb_swin_v1_tiny-epoch_235.pth", map_location='cpu')
state_dict = check_state_dict(state_dict)
model.load_state_dict(state_dict)
model.eval()
```
```
RuntimeError                              Traceback (most recent call last)
<ipython-input-14-be66c23d1019> in <cell line: 20>()
     18 state_dict = torch.load("/content/drive/MyDrive/birefnet-tiny/BiRefNet-massive-bb_swin_v1_tiny-epoch_235.pth", map_location='cpu')
     19 state_dict = check_state_dict(state_dict)
---> 20 model.load_state_dict(state_dict,  strict=False)
     21 model.eval()
     22 # model = model.to('cuda')

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in load_state_dict(self, state_dict, strict)
   2039 
   2040         if len(error_msgs) > 0:
-> 2041             raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
   2042                                self.__class__.__name__, "\n\t".join(error_msgs)))
   2043         return _IncompatibleKeys(missing_keys, unexpected_keys)

RuntimeError: Error(s) in loading state_dict for BiRefNet:
```
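The individual error messages are cut off above, but note that `strict=False` in PyTorch's `load_state_dict` only tolerates missing and unexpected keys; a parameter whose shape differs between checkpoint and model still raises this RuntimeError. That is what you would expect when a swin_v1_t checkpoint is loaded into a model built with a larger Swin backbone. A minimal, hypothetical diagnostic that works on plain `{name: shape}` dicts:

```python
from typing import Dict, List, Tuple

Shape = Tuple[int, ...]


def diff_state_dict_shapes(
    model_shapes: Dict[str, Shape], ckpt_shapes: Dict[str, Shape]
) -> Dict[str, List[str]]:
    """Group checkpoint problems into missing, unexpected and shape-mismatched keys."""
    missing = sorted(k for k in model_shapes if k not in ckpt_shapes)
    unexpected = sorted(k for k in ckpt_shapes if k not in model_shapes)
    # Shape mismatches are the ones strict=False does not silence.
    mismatched = sorted(
        f"{k}: checkpoint {ckpt_shapes[k]} vs model {model_shapes[k]}"
        for k in model_shapes
        if k in ckpt_shapes and ckpt_shapes[k] != model_shapes[k]
    )
    return {"missing": missing, "unexpected": unexpected, "mismatched": mismatched}
```

The inputs could be built with `{k: tuple(v.shape) for k, v in model.state_dict().items()}` and the same comprehension over the loaded `state_dict`. Many entries under "mismatched" point to a backbone that differs from the one the checkpoint was trained with.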
ZhengPeng7 commented 3 months ago

Hi, of course, I can help make sure everything is good; it's my job and my duty.

But the code you use is exactly the same as mine, so the problem is most probably related to the environment. Can you create a new env with Python 3.9 and torch==2.0.1, following the guideline in the README? It takes only two steps. Tell me if it works in the new env.
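After creating the env, a quick way to confirm the notebook or script is actually running in it is to compare the interpreter and installed torch versions against the ones mentioned above. A minimal sketch; the function name is hypothetical and the defaults simply restate Python 3.9 and torch==2.0.1:

```python
import sys
from importlib import metadata
from typing import List, Tuple


def env_mismatches(python: Tuple[int, int] = (3, 9), torch_version: str = "2.0.1") -> List[str]:
    """List differences between the running environment and the suggested one."""
    problems = []
    current = (sys.version_info.major, sys.version_info.minor)
    if current != python:
        problems.append(f"Python {current[0]}.{current[1]} (suggested {python[0]}.{python[1]})")
    try:
        installed = metadata.version("torch")
    except metadata.PackageNotFoundError:
        problems.append("torch is not installed")
    else:
        # Wheels may carry a local tag such as '+cu118'; compare only the part before '+'.
        if installed.split("+")[0] != torch_version:
            problems.append(f"torch {installed} (suggested {torch_version})")
    return problems


if __name__ == "__main__":
    for problem in env_mismatches() or ["environment matches"]:
        print(problem)
```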

MinGiSa commented 3 months ago

```python
    self.bb = [
        'vgg16', 'vgg16bn', 'resnet50',         # 0, 1, 2
        'swin_v1_t', 'swin_v1_s',               # 3, 4
        'swin_v1_b', 'swin_v1_l',               # 5-bs9, 6-bs4
        'pvt_v2_b0', 'pvt_v2_b1',               # 7, 8
        'pvt_v2_b2', 'pvt_v2_b5',               # 9-bs10, 10-bs5
    ][3]
```

In config.py, select 'swin_v1_t' with index [3]; it was [6].
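Because the backbone is picked by a zero-based index into that list, it is easy to count to the wrong entry. A small sketch, using the names copied from the excerpt above (the helper itself is hypothetical), that gives the index for a backbone name:

```python
# Backbone names in the order of the config.py excerpt above.
BACKBONES = [
    'vgg16', 'vgg16bn', 'resnet50',         # 0, 1, 2
    'swin_v1_t', 'swin_v1_s',               # 3, 4
    'swin_v1_b', 'swin_v1_l',               # 5, 6
    'pvt_v2_b0', 'pvt_v2_b1',               # 7, 8
    'pvt_v2_b2', 'pvt_v2_b5',               # 9, 10
]


def backbone_index(name: str) -> int:
    """Return the index to put after the list in config.py for a given backbone name."""
    if name not in BACKBONES:
        raise ValueError(f"unknown backbone {name!r}; expected one of {BACKBONES}")
    return BACKBONES.index(name)


if __name__ == "__main__":
    print(backbone_index('swin_v1_t'))
```

For the tiny Swin checkpoint this gives 3, and the previous default 'swin_v1_l' corresponds to 6, matching the change described above.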

ZhengPeng7 commented 3 months ago

Thanks, @MinGiSa, for the help. Sorry, @karndeb, my mistake: you ran it on Colab, so there's no problem with the environment. I also just tested single-image inference on Colab, and loading the version with swin_v1_t as the backbone works fine, as shown below. BTW, did you restart the session so that the updated self.bb in config.py is loaded? You can print Config().bb to check it.

[Screenshot 2024-06-17 22:39:35]
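Why the restart matters: Python caches imported modules in `sys.modules`, so editing config.py after it has been imported in a running notebook does not change what later imports see. The sketch below uses a throwaway `toy_config` module in a temporary directory (not BiRefNet's real config.py) to show the stale value and how `importlib.reload` picks up the edit. Other modules may already hold values created at import time, so restarting the session, as suggested above, remains the more reliable fix.

```python
import importlib
import sys
import tempfile
from pathlib import Path
from typing import Tuple

BACKBONES = ["vgg16", "vgg16bn", "resnet50", "swin_v1_t", "swin_v1_s", "swin_v1_b", "swin_v1_l"]


def write_config(directory: Path, index: int) -> None:
    """Write a toy config module that picks its backbone by list index, like config.py."""
    (directory / "toy_config.py").write_text(f"bb = {BACKBONES!r}[{index}]\n")


def demo_stale_import() -> Tuple[str, str, str]:
    """Return (value after first import, value after re-import, value after reload)."""
    old_flag = sys.dont_write_bytecode
    # Skip .pyc caching: the edited file has the same size and may share an mtime
    # with the original, which could otherwise make reload() reuse stale bytecode.
    sys.dont_write_bytecode = True
    with tempfile.TemporaryDirectory() as tmp:
        sys.path.insert(0, tmp)
        try:
            write_config(Path(tmp), 6)
            importlib.invalidate_caches()
            import toy_config

            first = toy_config.bb
            write_config(Path(tmp), 3)  # the edit from [6] to [3]
            import toy_config as reimported  # served from sys.modules; file not re-read

            stale = reimported.bb
            fresh = importlib.reload(toy_config).bb
        finally:
            sys.path.remove(tmp)
            sys.modules.pop("toy_config", None)
            sys.dont_write_bytecode = old_flag
    return first, stale, fresh


if __name__ == "__main__":
    print(demo_stale_import())
```

Running it prints `('swin_v1_l', 'swin_v1_l', 'swin_v1_t')`: the re-import still sees the old backbone until the module is reloaded.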

karndeb commented 3 months ago

@ZhengPeng7 and @MinGiSa, it's working perfectly now. Thanks for the quick response. I had changed the index in self.bb from 6 to 3, but I think I didn't restart the session. BTW, absolutely brilliant work. Very impressive. This can be closed now.

ZhengPeng7 commented 3 months ago

Great, glad to hear that.