facebookresearch / fairseq

Facebook AI Research Sequence-to-Sequence Toolkit written in Python.
MIT License

Saving wav2vec2 model results in serialization error #3503

Open harveenchadha opened 3 years ago

harveenchadha commented 3 years ago

Hi,

I am trying to save the compiled model using the code below:

w2v = torch.load(model_path)
model = Wav2VecCtc.build_model(args, target_dict)
model.load_state_dict(w2v["model"], strict=True)
torch.save(model, 'test_again.pt')

Whenever I execute this code, the model builds successfully, but torch.save throws an error:

[Screenshot (2021-04-24): traceback of the serialization error raised by torch.save]

I did not encounter this error in the non-Hydra version of fairseq. Can you please let me know what can be done?

Thanks!

lematt1991 commented 3 years ago

This sounds related to #3482. Can you try the workaround listed there for now?

harveenchadha commented 3 years ago

A workaround was suggested in #3482:

Add the following code to fairseq/fairseq/dataclass/utils.py at line 460:

cfg = OmegaConf.merge(
    OmegaConf.structured(
        MyConfigClassWhichNeedsToBePickled
    ),
    OmegaConf.create(
        OmegaConf.to_yaml(cfg, resolve=True)
    )
)
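For context on why this round-trip can help: OmegaConf.to_yaml followed by OmegaConf.create rebuilds the config from plain YAML text, so typed values (such as fairseq's runtime-generated enums) are replaced by primitive strings and numbers, which always pickle. A stdlib-only sketch of the same idea, with JSON standing in for YAML and a hypothetical Activation enum and config dict:

```python
import enum
import json
import pickle

# Hypothetical enum standing in for a typed config value. It is bound under
# a different name than it was created with, so pickle cannot look it up --
# similar to fairseq's runtime-generated ChoiceEnum classes.
Activation = enum.Enum("Choices", ["relu", "gelu"])

cfg = {"activation": Activation.gelu, "dropout": 0.1}

# A text round-trip (JSON here, YAML in the workaround above) keeps only
# primitives, so the rebuilt config pickles cleanly.
plain = json.loads(json.dumps({k: v.name if isinstance(v, enum.Enum) else v
                               for k, v in cfg.items()}))
print(pickle.loads(pickle.dumps(plain)))  # {'activation': 'gelu', 'dropout': 0.1}
```

The trade-off is that the round-tripped config loses its typed schema, which is why it only works as a quick fix.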
harveenchadha commented 3 years ago

Actually, there is still an issue and this error is popping up again. The code worked as a quick fix for old models only; with new models (trained with Hydra) it does not work!

villmow commented 3 years ago

I receive the same error when training a model with fairseq-hydra-train on multiple GPUs. Single GPU works.

minghao-wu commented 3 years ago

@villmow @harveenchadha Hi there,

Thank you very much for your solution, but it still doesn't work in my code.

I inserted your snippet at line 460 of fairseq/fairseq/dataclass/utils.py, and my function now looks like this:

def merge_with_parent(dc: FairseqDataclass, cfg: FairseqDataclass):
    cfg = OmegaConf.merge(
        OmegaConf.structured(
            dc
        ),
        OmegaConf.create(
            OmegaConf.to_yaml(cfg, resolve=True)
        )
    )
    merged_cfg = OmegaConf.merge(dc, cfg)
    merged_cfg.__dict__["_parent"] = cfg.__dict__["_parent"]
    OmegaConf.set_struct(merged_cfg, True)
    return merged_cfg

Is this the right way to use your snippet?

I am training my custom task using multiple GPUs and encountered a similar error.

More details can be found at https://github.com/pytorch/fairseq/issues/3634

Thanks again.

harveenchadha commented 3 years ago

Can you check if this version of fairseq works?

https://github.com/Open-Speech-EkStep/fairseq/tree/v2-hydra

duj12 commented 2 years ago

I met the same problem when saving a HuBERT model. To reproduce, one can use the open-source hubert_base model and run:

import fairseq, torch
model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task(['hubert_base_ls960.pt'])
torch.save(model[0], 'hubert_model.pt')

The error message is:

Traceback (most recent call last):
  File "/home/xmov/miniconda3/envs/fairseq/lib/python3.7/site-packages/IPython/core/interactiveshell.py", line 3553, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-7-db58c52b5dcc>", line 1, in <module>
    torch.save(model[0], '/data/megastore/Projects/DuJing/deployment/hubert/hubert_model.pt')
  File "/home/xmov/miniconda3/envs/fairseq/lib/python3.7/site-packages/torch/serialization.py", line 380, in save
    _save(obj, opened_zipfile, pickle_module, pickle_protocol)
  File "/home/xmov/miniconda3/envs/fairseq/lib/python3.7/site-packages/torch/serialization.py", line 589, in _save
    pickler.dump(obj)
_pickle.PicklingError: Can't pickle <enum 'Choices'>: attribute lookup Choices on fairseq.dataclass.constants failed

Does anyone have a solution?
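The error above can be reproduced without fairseq or torch. fairseq's ChoiceEnum helper builds enum classes at runtime, so the class's internal name ("Choices") differs from the attribute it is actually reachable under, and pickle's lookup of "Choices" on the defining module fails. A minimal stdlib sketch (the EXTRACTOR_MODE_CHOICES name is hypothetical):

```python
import enum
import pickle

# Create an enum at runtime, as fairseq's ChoiceEnum helper does. Its
# internal name is "Choices", but it ends up bound under a different
# attribute name, so pickle cannot find it by qualified name.
Choices = enum.Enum("Choices", ["first", "second"])
EXTRACTOR_MODE_CHOICES = Choices  # the name the rest of the code uses
del Choices                       # the internal name no longer resolves

try:
    pickle.dumps(EXTRACTOR_MODE_CHOICES.first)
except pickle.PicklingError as err:
    # Can't pickle <enum 'Choices'>: attribute lookup Choices on ... failed
    print(err)
```

This is why torch.save(model, ...) fails whenever the model object holds a config containing such an enum, while saving a plain state_dict does not.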

bellagodiva commented 1 year ago

Hi! My model contains fairseq's HuBERT model, and I also encountered the same error message when trying to save the model at the best validation result with torch.save(model, 'name.pt'). I also tried the solution mentioned in https://github.com/facebookresearch/fairseq/issues/3482, but to no avail. So instead of saving the entire model, I save and load the model's state_dict(), and it works!

torch.save(model.state_dict(), 'name.pt')

model = Model()
model.load_state_dict(torch.load('name.pt'))

reference: https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict

Hope this is helpful for those who are trying to save their model with a similar motivation. I am training with a single GPU.
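The state_dict approach sidesteps the pickling problem because only a plain container of tensors is serialized, never the model class or its config (where the unpicklable enums live). A minimal stdlib analogy, with a hypothetical Model class and floats standing in for tensors:

```python
import io
import pickle

# Hypothetical minimal model: parameters live in a plain dict, mirroring
# torch's state_dict. The dict pickles even when the full object would not.
class Model:
    def __init__(self):
        self.weights = {"w": 0.0, "b": 0.0}

    def state_dict(self):
        return dict(self.weights)

    def load_state_dict(self, sd):
        self.weights.update(sd)

trained = Model()
trained.weights.update(w=1.5, b=-0.5)   # pretend training happened

# Save only the parameter dict, as torch.save(model.state_dict(), ...) does:
buf = io.BytesIO()
pickle.dump(trained.state_dict(), buf)

# Rebuild the model from code, then load the saved parameters into it:
buf.seek(0)
restored = Model()
restored.load_state_dict(pickle.load(buf))
print(restored.weights)  # {'w': 1.5, 'b': -0.5}
```

The cost is that loading now requires the model-building code to be available, which is exactly the pattern the PyTorch serialization tutorial recommends.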