facebookresearch / vissl

VISSL is FAIR's library of extensible, modular and scalable components for SOTA Self-Supervised Learning with images.
https://vissl.ai
MIT License
3.25k stars 332 forks

How to convert VISSL to pytorch or torchscript #539

Open dereyly opened 2 years ago

dereyly commented 2 years ago

I want to take a pretrained VISSL model (like SEER) and fine-tune it in my own pipeline in pure PyTorch. Is there an option to convert from the VISSL format to pure PyTorch?

QuentinDuval commented 2 years ago

Hi @dereyly,

We have a script named convert_vissl_to_torchvision.py under extra_scripts that should do this.

You can use it like so:

python extra_scripts/convert_vissl_to_torchvision.py --model_url_or_file /path/to/vissl/model_final_checkpoint_phase299.torch --output_dir /output/folder --output_name torchvision.pth

Please tell me if that works out for you!

Thank you, Quentin

TopTea1 commented 2 years ago

Hi @QuentinDuval, I have tried doing it this way, but I'm getting this error when using RegNet from torchvision:

---------------------------------------------------------------------------

RuntimeError                              Traceback (most recent call last)

<ipython-input-26-0ebdee598139> in <module>()
----> 1 model.load_state_dict(torch.load("converted_vissl_converted.torch", map_location=torch.device('cpu')))

/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py in load_state_dict(self, state_dict, strict)
   1603         if len(error_msgs) > 0:
   1604             raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
-> 1605                                self.__class__.__name__, "\n\t".join(error_msgs)))
   1606         return _IncompatibleKeys(missing_keys, unexpected_keys)
   1607 

RuntimeError: Error(s) in loading state_dict for RegNet:
    Missing key(s) in state_dict: "stem.0.weight", "stem.1.weight", "stem.1.bias", "stem.1.running_mean", "stem.1.running_var", "trunk_output.block1.block1-0.proj.0.weight", "trunk_output.block1.block1-0.proj.1.weight", "trunk_output.block1.block1-0.proj.1.bias", "trunk_output.block1.block1-0.proj.1.running_mean", "trunk_output.block1.block1-0.proj.1.running_var", "trunk_output.block1.block1-0.f.a.0.weight", "trunk_output.block1.block1-0.f.a.1.weight", "trunk_output.block1.block1-0.f.a.1.bias", "trunk_output.block1.block1-0.f.a.1.running_mean", "trunk_output.block1.block1-0.f.a.1.running_var", "trunk_output.block1.block1-0.f.b.0.weight", "trunk_output.block1.block1-0.f.b.1.weight", "trunk_output.block1.block1-0.f.b.1.bias", "trunk_output.block1.block1-0.f.b.1.running_mean", "trunk_output.block1.block1-0.f.b.1.running_var", "trunk_output.block1.block1-0.f.se.fc1.weight", "trunk_output.block1.block1-0.f.se.fc1.bias", "trunk_output.block1.block1-0.f.se.fc2.weight", "trunk_output.block1.block1-0.f.se.fc2.bias", "trunk_output.block1.block1-0.f.c.0.weight", "trunk_output.block1.block1-0.f.c.1.weight", "trunk_output.block1.block1-0.f.c.1.bias", "trunk_output.block1.block1-0.f.c.1.running_mean", "trunk_output.block1.block1-0.f.c.1.running_var", "trunk_output.block1.block1-1.f.a.0.weight", "trunk_output.block1.block1-1.f.a.1.weight", "trunk_output.block1.block1-1.f.a.1.bias", "trunk_output.block1.block1-1.f.a.1.running_mean", "trunk_output.block1.block1-1.f.a.1.running_var", "trunk_ou...
    Unexpected key(s) in state_dict: "conv1.stem.0.weight", "conv1.stem.1.weight", "conv1.stem.1.bias", "conv1.stem.1.running_mean", "conv1.stem.1.running_var", "conv1.stem.1.num_batches_tracked", "res2.block1-0.proj.weight", "res2.block1-0.bn.weight", "res2.block1-0.bn.bias", "res2.block1-0.bn.running_mean", "res2.block1-0.bn.running_var", "res2.block1-0.bn.num_batches_tracked", "res2.block1-0.f.a.0.weight", "res2.block1-0.f.a.1.weight", "res2.block1-0.f.a.1.bias", "res2.block1-0.f.a.1.running_mean", "res2.block1-0.f.a.1.running_var", "res2.block1-0.f.a.1.num_batches_tracked", "res2.block1-0.f.b.0.weight", "res2.block1-0.f.b.1.weight", "res2.block1-0.f.b.1.bias", "res2.block1-0.f.b.1.running_mean", "res2.block1-0.f.b.1.running_var", "res2.block1-0.f.b.1.num_batches_tracked", "res2.block1-0.f.se.excitation.0.weight", "res2.block1-0.f.se.excitation.0.bias", "res2.block1-0.f.se.excitation.2.weight", "res2.block1-0.f.se.excitation.2.bias", "res2.block1-0.f.c.weight", "res2.block1-0.f.final_bn.weight", "res2.block1-0.f.final_bn.bias", "res2.block1-0.f.final_bn.running_mean", "res2.block1-0.f.final_bn.running_var", "res2.block1-0.f.final_bn.num_batches_tracked", "res2.block1-1.f.a.0.weight", "res2.block1-1.f.a.1.weight", "res2.block1-1.f.a.1.bias", "res2.block1-1.f.a.1.running_mean", "res2.block1-1.f.a.1.running_var", "res2.block1-1.f.a.1.num_batches_tracked", "res2.block1-1.f.b.0.weight", "res2.block1-1.f.b.1.weight", "res2.block1-1.f.b.1.bias", "res2.block1-1.f.b.1.running_mean"...
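For reference, the mismatch above is purely a naming difference: the converted checkpoint still uses VISSL's layer names (`conv1.stem`, `res2.block1-0`, ...) while torchvision's RegNet expects `stem`, `trunk_output.block1`, .... A hypothetical prefix remap, inferred only from the two key lists in the error (not an official converter, and it does not cover every difference, e.g. the SE-block and projection sublayers):

```python
import re

def vissl_to_torchvision_key(key):
    """Sketch of a prefix remap inferred from the error message; incomplete."""
    # "conv1.stem.*" -> "stem.*"
    key = key.replace("conv1.stem", "stem", 1)
    # "resN.*" -> "trunk_output.block(N-1).*"
    m = re.match(r"res(\d+)\.(.*)", key)
    if m:
        key = "trunk_output.block{}.{}".format(int(m.group(1)) - 1, m.group(2))
    return key
```

For example, `vissl_to_torchvision_key("res2.block1-0.f.a.0.weight")` yields `"trunk_output.block1.block1-0.f.a.0.weight"`. Deeper renames (such as `bn` vs `proj.1`, or `se.excitation.0` vs `se.fc1`) would need additional rules along the same lines.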

i7p9h9 commented 2 years ago

Hi, I hit the same issue. I resolved it in a somewhat dumb way, but it works :)

import torch
from torchvision.models import regnet

model = regnet.regnet_y_32gf()
state_seer = torch.load("path/to/converted/model/converted_vissl_seer.torch", map_location="cpu")
state_vision = model.state_dict()

keys_seer = list(state_seer.keys())
keys_vision = list(state_vision.keys())

# Both are OrderedDicts, so we can try to align the weights by order rather than by name
vision_to_seer_key = dict()
for n in range(max(len(keys_seer), len(keys_vision))):
    _key_seer = keys_seer[n] if n < len(keys_seer) else None
    _key_vision = keys_vision[n] if n < len(keys_vision) else None
    vision_to_seer_key[_key_vision] = _key_seer

for _key_vision_item in keys_vision:
    try:
        _key_seer_item = vision_to_seer_key[_key_vision_item]
        state_vision[_key_vision_item].copy_(state_seer[_key_seer_item])
    except Exception as e:
        print("=" * 32)
        print("can't load weights for {}".format(_key_vision_item))
        print(e)
        print("=" * 32)

yyy-Emily commented 1 year ago

Hi, I want to know how to convert .torch to .pth. I have tried this:

python anaconda3/lib/python3.9/site-packages/extra_scripts/convert_vissl_to_torchvision.py --model_url_or_file simclr_class/model_final_checkpoint_phase799.torch --output_dir simclr_class/ --output_name torchvision.pth

but it did not work. Could you tell me how to do this task? Thank you.

yyy-Emily commented 1 year ago

Besides, it outputs a .pth.torch file. I have tried removing the .torch extension, but that did not work either.
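For what it's worth, torch.load does not care about the file extension, so the extra .torch suffix is cosmetic. A small helper to strip it when renaming (a sketch; the converter only appears to append ".torch" based on the reports here):

```python
def strip_torch_suffix(path):
    # The converter appears to append ".torch" to the chosen output name;
    # strip it so the file ends in ".pth" (torch.load ignores extensions anyway).
    return path[:-len(".torch")] if path.endswith(".torch") else path
```

E.g. `strip_torch_suffix("simclr_class/torchvision.pth.torch")` gives `"simclr_class/torchvision.pth"`, which can then be used with `os.rename`. If loading still fails after renaming, the problem is the state-dict keys (see the RegNet discussion above), not the filename.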

yugambahl commented 12 months ago

Was anyone able to convert successfully to .pth? I am also getting .pth.torch as output.

yugambahl commented 9 months ago

You can refer to this Hugging Face implementation of SEER 10B if it helps!

https://huggingface.co/facebook/regnet-y-10b-seer